{-# LANGUAGE TypeFamilies #-}
module Data.Array.Comfort.Boxed.Unchecked where

import qualified Data.Array.Comfort.Shape as Shape
import qualified Data.Primitive.Array as Prim

import qualified Control.Monad.ST.Strict as STStrict
import qualified Control.Monad.Trans.Class as MT
import qualified Control.Monad.Trans.State as MS
import Control.Monad (liftM)
import Control.Applicative (Applicative, pure, (<*>), (<$>))
import Control.DeepSeq (NFData, rnf)

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import qualified Data.List as List
import Prelude hiding (map, zipWith, replicate)


-- | A boxed array: a shape value paired with a primitive array of elements.
--
-- Nothing in this module verifies that the length of 'buffer' agrees with
-- @Shape.size (shape arr)@; functions such as 'reshape' and 'zipWith' leave
-- that invariant to the caller.
data Array sh a =
   Array
      { shape :: sh
        -- ^ Shape descriptor; 'Shape.uncheckedOffset' maps its indices to
        -- buffer offsets.
      , buffer :: Prim.Array a
        -- ^ Element storage, addressed by offset.
      }

-- | Renders the array as @BoxedArray.fromList \<shape\> \<elements\>@,
-- parenthesised when the surrounding precedence exceeds 10.  Only the
-- first @Shape.size shape@ elements (as given by 'toListLazy') are shown.
instance (Shape.C sh, Show sh, Show a) => Show (Array sh a) where
   showsPrec prec arr =
      showParen (prec > 10) rendered
      where
         rendered =
            showString "BoxedArray.fromList "
            . showsPrec 11 (shape arr)
            . showChar ' '
            . shows (toListLazy arr)


-- | Fully evaluates the shape, then the first @Shape.size shape@ elements
-- (exactly those returned by 'toListLazy').
instance (Shape.C sh, NFData sh, NFData a) => NFData (Array sh a) where
   rnf arr@(Array sh _) = rnf sh `seq` rnf (toListLazy arr)

-- | Lazy element-wise mapping: result elements are stored unevaluated.
--
-- This deliberately delegates to the lazy 'fmap' of the underlying
-- primitive array instead of to 'map', which uses the element-strict
-- @Prim.mapArray'@.  A strict 'fmap' violates @fmap id == id@ for arrays
-- that contain undefined elements.  It also disagrees with
-- @pure f \<*\> x@ (built from the lazy 'replicate' and 'zipWith') and with
-- the 'Trav.Traversable' instance.  Call 'map' directly when
-- element-strict mapping is wanted.
instance (Shape.C sh) => Functor (Array sh) where
   fmap f (Array sh arr) = Array sh (fmap f arr)

{- |
'pure' must produce a shape out of nothing, so this instance is restricted
to 'Shape.Static' shapes and uses 'Shape.static'.
Because the shape is static, '<*>' performs no size check: it trusts that
both operands carry that same static shape.
-}
instance (Shape.Static sh) => Applicative (Array sh) where
   pure x = replicate Shape.static x
   fs <*> xs = zipWith id fs xs

-- | Every fold delegates to the 'Fold.Foldable' instance of the underlying
-- 'buffer', so it visits all buffer elements rather than a prefix of
-- length @Shape.size shape@.
instance (Shape.C sh) => Fold.Foldable (Array sh) where
   fold arr = Fold.fold (buffer arr)
   foldMap f arr = Fold.foldMap f (buffer arr)
   foldl f z arr = Fold.foldl f z (buffer arr)
   foldr f z arr = Fold.foldr f z (buffer arr)
   foldl1 f arr = Fold.foldl1 f (buffer arr)
   foldr1 f arr = Fold.foldr1 f (buffer arr)

-- | Traverses the underlying 'buffer' and re-attaches the unchanged shape.
-- 'mapM' and 'sequence' only rely on the 'Monad' constraint, hence they
-- rebuild the result with 'return' rather than 'fmap'.
instance (Shape.C sh) => Trav.Traversable (Array sh) where
   traverse f (Array sh arr) = fmap (Array sh) (Trav.traverse f arr)
   sequenceA (Array sh arr) = fmap (Array sh) (Trav.sequenceA arr)
   mapM f (Array sh arr) = Trav.mapM f arr >>= return . Array sh
   sequence (Array sh arr) = Trav.sequence arr >>= return . Array sh


-- | Replace the shape of an array, keeping its buffer untouched.
--
-- No check is made that the new shape has the same size as the old one.
-- TODO: add such an assertion, at least in an exposed version.
reshape :: sh1 -> Array sh0 a -> Array sh1 a
reshape sh arr = arr {shape = sh}

-- | Transform the shape with a function, keeping the buffer untouched.
-- The sizes of the old and new shapes are not compared.
mapShape :: (sh0 -> sh1) -> Array sh0 a -> Array sh1 a
mapShape f arr = arr {shape = f (shape arr)}


infixl 9 !

-- | Element lookup by shape index, without bounds checking: the index is
-- converted to a buffer offset with 'Shape.uncheckedOffset'.
(!) :: (Shape.Indexed sh) => Array sh a -> Shape.Index sh -> a
(!) (Array sh arr) ix =
   let offset = Shape.uncheckedOffset sh ix
   in Prim.indexArray arr offset

-- | The first @Shape.size sh@ buffer elements, in offset order.  Each list
-- cell is an unevaluated 'Prim.indexArray' lookup, which keeps a reference
-- to the buffer until it is demanded.
toListLazy :: (Shape.C sh) => Array sh a -> [a]
toListLazy (Array sh arr) =
   [Prim.indexArray arr k | k <- [0 .. Shape.size sh - 1]]

-- | Like 'toListLazy', but every buffer lookup is performed eagerly inside
-- 'STStrict.runST' via 'Prim.indexArrayM'.  The returned list therefore
-- holds the elements themselves (not forced) instead of lookup thunks
-- referencing the buffer.
toList :: (Shape.C sh) => Array sh a -> [a]
toList (Array sh arr) =
   STStrict.runST
      (Trav.traverse (Prim.indexArrayM arr) [0 .. Shape.size sh - 1])

-- | Build an array from a shape and exactly @Shape.size sh@ elements.
-- 'Prim.fromListN' enforces that length and fails on a mismatch.
fromList :: (Shape.C sh) => sh -> [a] -> Array sh a
fromList sh = Array sh . Prim.fromListN (Shape.size sh)

-- | Build a zero-based one-dimensional array whose extent equals the
-- length of the given list.
vectorFromList :: [a] -> Array (Shape.ZeroBased Int) a
vectorFromList xs = Array (Shape.ZeroBased (Prim.sizeofArray buf)) buf
   where
      buf = Prim.fromList xs

-- | An array of the given shape in which every cell holds the same value.
replicate :: (Shape.C sh) => sh -> a -> Array sh a
replicate sh a = Array sh filled
   where
      -- Freezing in place is safe: the mutable array never escapes runST.
      filled =
         STStrict.runST
            (Prim.newArray (Shape.size sh) a >>= Prim.unsafeFreezeArray)

-- | Element-strict mapping: @Prim.mapArray'@ evaluates every result element
-- to weak head normal form while building the new buffer.
map :: (Shape.C sh) => (a -> b) -> Array sh a -> Array sh b
map f arr =
   case arr of
      Array sh buf -> Array sh (Prim.mapArray' f buf)

-- | Combine two arrays element by element, keeping the shape of the first.
--
-- Unchecked: the shape of the second array is ignored.  Its buffer is read
-- at offsets @0, 1, ...@ for as many elements as the first buffer holds,
-- so the caller must ensure the second buffer is at least that long.
zipWith ::
   (Shape.C sh) => (a -> b -> c) -> Array sh a -> Array sh b -> Array sh c
zipWith f (Array sha arra) (Array _shb arrb) =
   Array sha $
   STStrict.runST
      (MS.evalStateT
         (Prim.traverseArrayP
            (\a -> do
               -- take the current offset and advance the counter
               k <- MS.state (\j -> (j, j + 1))
               b <- MT.lift (Prim.indexArrayM arrb k)
               return (f a b))
            arra)
         0)