{-# LANGUAGE TypeFamilies #-}
module Data.Array.Comfort.Boxed.Strict.Unchecked where

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Boxed.Unchecked (Array(Array))

import qualified Data.Primitive.Array as Prim

import qualified Control.Monad.ST.Strict as STStrict
import qualified Control.Monad.Trans.Class as MT
import qualified Control.Monad.Trans.State as MS

import Prelude hiding (map, zipWith)


-- | Convert an array to the list of its elements in storage order.
--
-- The number of elements read is @Shape.size sh@; they are fetched from the
-- underlying primitive array at offsets @0 .. Shape.size sh - 1@.
-- The lookups are sequenced inside a strict 'STStrict.runST' via
-- 'Prim.indexArrayM', so every array read has been performed by the time
-- the list is returned. The elements themselves are not forced here.
toList :: (Shape.C sh) => Array sh a -> [a]
toList (Array sh arr) =
   STStrict.runST (mapM (Prim.indexArrayM arr) $ take (Shape.size sh) [0..])

-- | Apply a function to every element of an array, keeping its shape.
--
-- This delegates to 'Prim.mapArray'', the strict mapping function of the
-- @primitive@ package, so (per that function's contract) the results are
-- evaluated while the new array is built instead of being stored as thunks.
map :: (Shape.C sh) => (a -> b) -> Array sh a -> Array sh b
map f (Array sh arr) = Array sh $ Prim.mapArray' f arr

-- | Combine two arrays element by element.
--
-- The result takes the shape of the first array. The shape of the second
-- array is ignored and /not/ compared to the first one: this is the
-- unchecked variant. The second array is read at offsets @0, 1, 2, ...@,
-- one for every element of the first array, so it must contain at least
-- as many elements as the first array.
-- NOTE(review): 'Prim.indexArrayM' presumably performs no bounds check in
-- default builds of @primitive@, so a too short second array would be an
-- out-of-bounds read - confirm against the library documentation.
zipWith ::
   (Shape.C sh) => (a -> b -> c) -> Array sh a -> Array sh b -> Array sh c
zipWith f (Array sha arra) (Array _shb arrb) =
   Array sha $
   STStrict.runST
      -- The Int state is the offset of the next element to fetch from arrb;
      -- it starts at 0 and advances by one per visited element of arra.
      (flip MS.evalStateT 0 $
       Prim.traverseArrayP
         (\a -> do
            k <- MS.get
            b <- MT.lift $ Prim.indexArrayM arrb k
            MS.put (k+1)
            return $ f a b)
         arra)