{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Massiv.Array.Ops.Map
( map
, imap
, traverseA
, traverseA_
, itraverseA
, itraverseA_
, traverseAR
, itraverseAR
, sequenceA
, sequenceA_
, traverseS
, traversePrim
, itraversePrim
, traversePrimR
, itraversePrimR
, mapM
, mapMR
, forM
, forMR
, imapM
, imapMR
, iforM
, iforMR
, mapM_
, forM_
, imapM_
, iforM_
, mapIO
, mapWS
, mapIO_
, imapIO
, imapWS
, imapIO_
, forIO
, forWS
, forIO_
, iforIO
, iforWS
, iforIO_
, imapSchedulerM_
, iforSchedulerM_
, zip
, zip3
, unzip
, unzip3
, zipWith
, zipWith3
, izipWith
, izipWith3
, liftArray2
, zipWithA
, izipWithA
, zipWith3A
, izipWith3A
) where
import Control.Monad (void)
import Control.Monad.Primitive (PrimMonad)
import Control.Scheduler
import Data.Coerce
import Data.Massiv.Array.Delayed.Pull
import Data.Massiv.Array.Delayed.Stream
import Data.Massiv.Array.Mutable
import Data.Massiv.Array.Ops.Construct (makeArrayA, makeArrayLinearA)
import Data.Massiv.Core.Common
import Data.Massiv.Core.Index.Internal (Sz(..))
import Prelude hiding (map, mapM, mapM_, sequenceA, traverse, unzip, unzip3,
zip, zip3, zipWith, zipWith3)
-- | Map a pure function over every element of an array.
--
-- Element values are not computed here: the result is a delayed ('D') array
-- built by 'imap', with the same size and computation strategy as the source.
map :: Source r ix e' => (e' -> e) -> Array r ix e' -> Array D ix e
map f = imap (\ _ix -> f)
{-# INLINE map #-}
-- | Map a pure function over every element together with its index.
--
-- The source array is forced to WHNF. The result is a delayed ('D') array
-- carrying the source's computation strategy and size. The supplied function
-- is only stored inside the 'DArray', not applied here.
imap :: Source r ix e' => (ix -> e' -> e) -> Array r ix e' -> Array D ix e
imap f !arr = DArray comp sz elemAt
  where
    comp = getComp arr
    sz = size arr
    elemAt !ix = f ix (unsafeIndex arr ix)
{-# INLINE imap #-}
-- | Pair up the elements of two arrays that share an index.
--
-- Built on 'zipWith', so the result size is the per-dimension minimum of the
-- two input sizes.
zip :: (Source r1 ix e1, Source r2 ix e2)
    => Array r1 ix e1 -> Array r2 ix e2 -> Array D ix (e1, e2)
zip = zipWith (\ x y -> (x, y))
{-# INLINE zip #-}
-- | Combine the elements of three arrays that share an index into triples.
--
-- Built on 'zipWith3', so the result size is the per-dimension minimum of all
-- three input sizes.
zip3 :: (Source r1 ix e1, Source r2 ix e2, Source r3 ix e3)
     => Array r1 ix e1 -> Array r2 ix e2 -> Array r3 ix e3 -> Array D ix (e1, e2, e3)
zip3 = zipWith3 (\ x y z -> (x, y, z))
{-# INLINE zip3 #-}
-- | Split an array of pairs into a delayed array of first components and a
-- delayed array of second components.
--
-- Each result is an independent 'map' over the source, so the source is
-- indexed separately for each component.
unzip :: Source r ix (e1, e2) => Array r ix (e1, e2) -> (Array D ix e1, Array D ix e2)
unzip arr = (firsts, seconds)
  where
    firsts = map fst arr
    seconds = map snd arr
{-# INLINE unzip #-}
-- | Split an array of triples into three delayed arrays, one per component.
--
-- Each result is an independent 'map' over the source.
unzip3 :: Source r ix (e1, e2, e3)
       => Array r ix (e1, e2, e3) -> (Array D ix e1, Array D ix e2, Array D ix e3)
unzip3 arr = (firsts, seconds, thirds)
  where
    firsts = map (\ (x, _, _) -> x) arr
    seconds = map (\ (_, y, _) -> y) arr
    thirds = map (\ (_, _, z) -> z) arr
{-# INLINE unzip3 #-}
-- | Combine two arrays element-wise with a pure function.
--
-- This is 'izipWith' with the index ignored, so the result size is the
-- per-dimension minimum of both input sizes.
zipWith :: (Source r1 ix e1, Source r2 ix e2)
        => (e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
zipWith f = izipWith (const f)
{-# INLINE zipWith #-}
-- | Combine two arrays element-wise with a function that also receives the
-- index.
--
-- * Size: the per-dimension minimum of the two input sizes. Elements beyond
--   the smaller extent in any dimension are never read.
-- * Computation strategy: the two inputs' strategies combined with '<>'.
izipWith :: (Source r1 ix e1, Source r2 ix e2)
         => (ix -> e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
izipWith f arr1 arr2 = DArray comp sz elemAt
  where
    comp = getComp arr1 <> getComp arr2
    -- 'coerce' strips the 'Sz' wrapper so sizes can be combined index-wise.
    sz = SafeSz (liftIndex2 min (coerce (size arr1)) (coerce (size arr2)))
    elemAt !ix = f ix (unsafeIndex arr1 ix) (unsafeIndex arr2 ix)
{-# INLINE izipWith #-}
-- | Combine three arrays element-wise with a pure function.
--
-- This is 'izipWith3' with the index ignored, so the result size is the
-- per-dimension minimum of all three input sizes.
zipWith3 :: (Source r1 ix e1, Source r2 ix e2, Source r3 ix e3)
         => (e1 -> e2 -> e3 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array r3 ix e3 -> Array D ix e
zipWith3 f = izipWith3 (const f)
{-# INLINE zipWith3 #-}
-- | Combine three arrays element-wise with a function that also receives the
-- index.
--
-- * Size: the per-dimension minimum of the three input sizes.
-- * Computation strategy: all three strategies combined with '<>'.
izipWith3
  :: (Source r1 ix e1, Source r2 ix e2, Source r3 ix e3)
  => (ix -> e1 -> e2 -> e3 -> e)
  -> Array r1 ix e1
  -> Array r2 ix e2
  -> Array r3 ix e3
  -> Array D ix e
izipWith3 f arr1 arr2 arr3 = DArray comp sz elemAt
  where
    comp = getComp arr1 <> getComp arr2 <> getComp arr3
    -- 'coerce' strips the 'Sz' wrappers so sizes can be combined index-wise.
    sz =
      SafeSz $
        liftIndex2 min (liftIndex2 min (coerce (size arr1)) (coerce (size arr2))) (coerce (size arr3))
    elemAt !ix = f ix (unsafeIndex arr1 ix) (unsafeIndex arr2 ix) (unsafeIndex arr3 ix)
{-# INLINE izipWith3 #-}
-- | 'Applicative' counterpart of 'zipWith'. It builds a manifest array of
-- representation @r@ from an effectful combining function.
--
-- This is 'izipWithA' with the index ignored.
zipWithA ::
     (Source r1 ix e1, Source r2 ix e2, Applicative f, Mutable r ix e)
  => (e1 -> e2 -> f e)
  -> Array r1 ix e1
  -> Array r2 ix e2
  -> f (Array r ix e)
zipWithA f = izipWithA (\ _ix -> f)
{-# INLINE zipWithA #-}
-- | Index-aware 'Applicative' zip of two arrays into a manifest array of
-- representation @r@.
--
-- The result is produced with 'makeArrayA' over the per-dimension minimum of
-- the two sizes. Its computation strategy is then set to the inputs'
-- strategies combined with '<>'.
izipWithA ::
     (Source r1 ix e1, Source r2 ix e2, Applicative f, Mutable r ix e)
  => (ix -> e1 -> e2 -> f e)
  -> Array r1 ix e1
  -> Array r2 ix e2
  -> f (Array r ix e)
izipWithA f arr1 arr2 = fmap (setComp comp) (makeArrayA sz produce)
  where
    comp = getComp arr1 <> getComp arr2
    sz = SafeSz (liftIndex2 min (coerce (size arr1)) (coerce (size arr2)))
    produce !ix = f ix (unsafeIndex arr1 ix) (unsafeIndex arr2 ix)
{-# INLINE izipWithA #-}
-- | 'Applicative' counterpart of 'zipWith3'. It builds a manifest array of
-- representation @r@ from an effectful combining function.
--
-- This is 'izipWith3A' with the index ignored.
zipWith3A ::
     (Source r1 ix e1, Source r2 ix e2, Source r3 ix e3, Applicative f, Mutable r ix e)
  => (e1 -> e2 -> e3 -> f e)
  -> Array r1 ix e1
  -> Array r2 ix e2
  -> Array r3 ix e3
  -> f (Array r ix e)
zipWith3A f = izipWith3A (\ _ix -> f)
{-# INLINE zipWith3A #-}
-- | Index-aware 'Applicative' zip of three arrays into a manifest array of
-- representation @r@.
--
-- The result is produced with 'makeArrayA' over the per-dimension minimum of
-- the three sizes. Its computation strategy is then set to all three inputs'
-- strategies combined with '<>'.
izipWith3A ::
     (Source r1 ix e1, Source r2 ix e2, Source r3 ix e3, Applicative f, Mutable r ix e)
  => (ix -> e1 -> e2 -> e3 -> f e)
  -> Array r1 ix e1
  -> Array r2 ix e2
  -> Array r3 ix e3
  -> f (Array r ix e)
izipWith3A f arr1 arr2 arr3 = fmap (setComp comp) (makeArrayA sz produce)
  where
    comp = getComp arr1 <> getComp arr2 <> getComp arr3
    sz =
      SafeSz $
        liftIndex2 min (liftIndex2 min (coerce (size arr1)) (coerce (size arr2))) (coerce (size arr3))
    produce !ix = f ix (unsafeIndex arr1 ix) (unsafeIndex arr2 ix) (unsafeIndex arr3 ix)
{-# INLINE izipWith3A #-}
-- | Combine two arrays element-wise, with broadcasting of single-element
-- arrays. Both arrays are forced to WHNF. The cases are checked in this order:
--
-- * If the first array has size 'oneSz', its only element is combined with
--   every element of the second. The result uses the second array's strategy.
-- * Otherwise, if the second array has size 'oneSz', its only element is
--   combined with every element of the first.
-- * Otherwise the sizes must be equal. The result then has that size and the
--   inputs' strategies combined with '<>'.
-- * Any other size combination throws 'SizeMismatchException'.
liftArray2
  :: (Source r1 ix a, Source r2 ix b)
  => (a -> b -> e) -> Array r1 ix a -> Array r2 ix b -> Array D ix e
liftArray2 f !arr1 !arr2
  | sz1 == oneSz = map (f single1) arr2
  | sz2 == oneSz = map (\ e1 -> f e1 single2) arr1
  | sz1 == sz2 = DArray comp sz1 combine
  | otherwise = throw (SizeMismatchException sz1 sz2)
  where
    sz1 = size arr1
    sz2 = size arr2
    comp = getComp arr1 <> getComp arr2
    -- Lazily bound: only read in the matching broadcast branch.
    single1 = unsafeIndex arr1 zeroIndex
    single2 = unsafeIndex arr2 zeroIndex
    combine !ix = f (unsafeIndex arr1 ix) (unsafeIndex arr2 ix)
{-# INLINE liftArray2 #-}
-- | Traverse an array with an 'Applicative' action, producing a manifest array
-- of representation @r@.
--
-- The action is applied to the source element at every linear index (read
-- with 'unsafeLinearIndex'), and the results are assembled by
-- 'makeArrayLinearA'.
--
-- The source's computation strategy is copied onto the result with 'setComp'.
-- 'itraverseA', 'itraversePrim' and the @izipWith*A@ functions do the same, so
-- a traversal with or without the index yields an array with the same
-- strategy.
traverseA ::
     forall r ix e r' a f . (Source r' ix a, Mutable r ix e, Applicative f)
  => (a -> f e)
  -> Array r' ix a
  -> f (Array r ix e)
traverseA f arr =
  setComp (getComp arr) <$> makeArrayLinearA (size arr) (f . unsafeLinearIndex arr)
{-# INLINE traverseA #-}
-- | Run an 'Applicative' action on every element and discard the results.
--
-- Linear indices are visited starting at 0 and stepping by one while below
-- the total element count.
traverseA_ :: forall r ix e a f . (Source r ix e, Applicative f) => (e -> f a) -> Array r ix e -> f ()
traverseA_ f arr = loopA_ 0 (< total) (+ 1) (\ i -> f (unsafeLinearIndex arr i))
  where
    total = totalElem (size arr)
{-# INLINE traverseA_ #-}
-- | Run the 'Applicative' action stored at every element and collect the
-- results into a manifest array of representation @r@.
--
-- This is 'traverseA' with the identity action.
sequenceA ::
     forall r ix e r' f. (Source r' ix (f e), Mutable r ix e, Applicative f)
  => Array r' ix (f e)
  -> f (Array r ix e)
sequenceA = traverseA (\ action -> action)
{-# INLINE sequenceA #-}
-- | Run the 'Applicative' action stored at every element, discarding the
-- results. This is 'traverseA_' with the identity action.
sequenceA_ :: forall r ix e f . (Source r ix (f e), Applicative f) => Array r ix (f e) -> f ()
sequenceA_ = traverseA_ (\ action -> action)
{-# INLINE sequenceA_ #-}
-- | Index-aware 'traverseA'. The action receives each element's index as well
-- as its value.
--
-- The result is built with 'makeArrayA' over the source size and carries the
-- source's computation strategy.
itraverseA ::
     forall r ix e r' a f . (Source r' ix a, Mutable r ix e, Applicative f)
  => (ix -> a -> f e)
  -> Array r' ix a
  -> f (Array r ix e)
itraverseA f arr = fmap (setComp comp) (makeArrayA (size arr) visit)
  where
    comp = getComp arr
    visit !ix = f ix (unsafeIndex arr ix)
{-# INLINE itraverseA #-}
-- | Index-aware 'traverseA_' that discards the results.
--
-- Linear indices run from 0 in steps of one while below the total element
-- count. Each one is converted to a multi-dimensional index with
-- 'fromLinearIndex' before the action is called.
itraverseA_ ::
     forall r ix e a f. (Source r ix a, Applicative f)
  => (ix -> a -> f e)
  -> Array r ix a
  -> f ()
itraverseA_ f arr = loopA_ 0 (< total) (+ 1) step
  where
    sz = size arr
    total = totalElem sz
    step !i = f (fromLinearIndex sz i) (unsafeLinearIndex arr i)
{-# INLINE itraverseA_ #-}
-- | Deprecated alias of 'traverseA'.
--
-- The first argument is never inspected. It only fixes the result
-- representation @r@ at the type level.
traverseAR ::
     (Source r' ix a, Mutable r ix b, Applicative f)
  => r
  -> (a -> f b)
  -> Array r' ix a
  -> f (Array r ix b)
traverseAR _repr f arr = traverseA f arr
{-# INLINE traverseAR #-}
{-# DEPRECATED traverseAR "In favor of `traverseA`" #-}
-- | Deprecated alias of 'itraverseA'.
--
-- The first argument is never inspected. It only fixes the result
-- representation @r@ at the type level.
itraverseAR ::
     (Source r' ix a, Mutable r ix b, Applicative f)
  => r
  -> (ix -> a -> f b)
  -> Array r' ix a
  -> f (Array r ix b)
itraverseAR _repr f arr = itraverseA f arr
{-# INLINE itraverseAR #-}
{-# DEPRECATED itraverseAR "In favor of `itraverseA`" #-}
-- | Traverse with a 'PrimMonad' action, producing a manifest array of
-- representation @r@. This is 'itraversePrim' with the index ignored.
traversePrim ::
     forall r ix b r' a m . (Source r' ix a, Mutable r ix b, PrimMonad m)
  => (a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
traversePrim f = itraversePrim (\ _ix -> f)
{-# INLINE traversePrim #-}
-- | Index-aware traversal with a 'PrimMonad' action.
--
-- The result is generated by 'generateArrayLinearS' over linear indices. The
-- @S@ suffix presumably means sequential; confirm in that function's docs.
-- Each linear index is converted to a multi-dimensional index with
-- 'fromLinearIndex' before the action is called. The source's computation
-- strategy is copied onto the result.
itraversePrim ::
     forall r ix b r' a m . (Source r' ix a, Mutable r ix b, PrimMonad m)
  => (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
itraversePrim f arr = setComp comp <$> generateArrayLinearS sz step
  where
    sz = size arr
    comp = getComp arr
    step !i = f (fromLinearIndex sz i) (unsafeLinearIndex arr i)
{-# INLINE itraversePrim #-}
-- | Deprecated alias of 'traversePrim'.
--
-- The first argument is never inspected. It only fixes the result
-- representation @r@ at the type level.
traversePrimR ::
     (Source r' ix a, Mutable r ix b, PrimMonad m)
  => r
  -> (a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
traversePrimR _repr f arr = traversePrim f arr
{-# INLINE traversePrimR #-}
{-# DEPRECATED traversePrimR "In favor of `traversePrim`" #-}
-- | Deprecated alias of 'itraversePrim'.
--
-- The first argument is never inspected. It only fixes the result
-- representation @r@ at the type level.
itraversePrimR ::
     (Source r' ix a, Mutable r ix b, PrimMonad m)
  => r
  -> (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
itraversePrimR _repr f arr = itraversePrim f arr
{-# INLINE itraversePrimR #-}
{-# DEPRECATED itraversePrimR "In favor of `itraversePrim`" #-}
-- | Monadic map into a manifest array of representation @r@. This is
-- 'traverseA' specialised to a 'Monad'.
mapM ::
     forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => (a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
mapM f arr = traverseA f arr
{-# INLINE mapM #-}
-- | 'mapM' with an extra, ignored argument whose type selects the result
-- representation @r@.
mapMR ::
     forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => r
  -> (a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
mapMR _repr f arr = traverseA f arr
{-# INLINE mapMR #-}
-- | 'mapM' with its arguments flipped: the array comes first, then the action.
forM ::
     forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => Array r' ix a
  -> (a -> m b)
  -> m (Array r ix b)
forM arr f = traverseA f arr
{-# INLINE forM #-}
-- | 'forM' with an extra, ignored leading argument whose type selects the
-- result representation @r@.
forMR ::
     forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => r
  -> Array r' ix a
  -> (a -> m b)
  -> m (Array r ix b)
forMR _repr arr f = traverseA f arr
{-# INLINE forMR #-}
-- | Index-aware monadic map into a manifest array of representation @r@.
-- This is 'itraverseA' specialised to a 'Monad'.
imapM ::
     forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
imapM f arr = itraverseA f arr
{-# INLINE imapM #-}
-- | 'imapM' with an extra, ignored leading argument whose type selects the
-- result representation @r@.
imapMR ::
     forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => r
  -> (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
imapMR _repr f arr = itraverseA f arr
{-# INLINE imapMR #-}
-- | Index-aware monadic traversal into a manifest array of representation
-- @r@, implemented with 'itraverseA'.
--
-- NOTE(review): despite the @for@ prefix, this takes the function first and
-- the array second, the same order as 'imapM'. 'forM', 'iforM_', 'iforIO' and
-- 'iforWS' in this module all take the array first. Confirm whether this order
-- is intended. Flipping it is an API-breaking change and is deliberately not
-- done here.
iforM ::
     forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
iforM f arr = itraverseA f arr
{-# INLINE iforM #-}
-- | Index-aware monadic traversal with an extra, ignored leading argument
-- whose type selects the result representation @r@. Implemented with
-- 'itraverseA'.
--
-- NOTE(review): after the representation argument, the function comes before
-- the array. This is the same order as 'imapMR', not the flipped order used by
-- 'forMR'. Confirm whether this is intended before relying on it.
iforMR ::
     forall r ix b r' a m. (Source r' ix a, Mutable r ix b, Monad m)
  => r
  -> (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
iforMR _repr f arr = itraverseA f arr
{-# INLINE iforMR #-}
-- | Run a monadic action on every element and discard the results.
--
-- The array is forced to WHNF. Iteration goes through 'iterM_' over the
-- multi-dimensional indices from 'zeroIndex' up to (exclusive, via '<') the
-- array size, stepping by one in every dimension.
mapM_ :: (Source r ix a, Monad m) => (a -> m b) -> Array r ix a -> m ()
mapM_ f !arr = iterM_ zeroIndex (unSz sz) (pureIndex 1) (<) visit
  where
    sz = size arr
    visit ix = f (unsafeIndex arr ix)
{-# INLINE mapM_ #-}
-- | 'mapM_' with its arguments flipped: the array comes first, then the action.
forM_ :: (Source r ix a, Monad m) => Array r ix a -> (a -> m b) -> m ()
forM_ arr f = mapM_ f arr
{-# INLINE forM_ #-}
-- | 'imapM_' with its arguments flipped: the array comes first, then the
-- index-aware action.
iforM_ :: (Source r ix a, Monad m) => Array r ix a -> (ix -> a -> m b) -> m ()
iforM_ arr f = imapM_ f arr
{-# INLINE iforM_ #-}
-- | Map an action over every element into a manifest array of representation
-- @r@, honouring the source's computation strategy. This is 'imapIO' with the
-- index ignored.
mapIO ::
     forall r ix b r' a m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => (a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
mapIO action = imapIO (\ _ix -> action)
{-# INLINE mapIO #-}
-- | Run an action on every element and discard the results. Work is scheduled
-- according to the array's computation strategy; see 'imapIO_'.
--
-- The index type variable is named @ix@ to match every other signature in this
-- module. It was previously @b@. The implicit type-application order
-- (@r@, @ix@, @e@, @m@, @a@) is unchanged.
mapIO_ :: (Source r ix e, MonadUnliftIO m) => (e -> m a) -> Array r ix e -> m ()
mapIO_ action = imapIO_ (const action)
{-# INLINE mapIO_ #-}
-- | Run an index-aware action on every element and discard the results.
--
-- A scheduler is created with 'withScheduler_' from the array's computation
-- strategy, and the work is handed to 'imapSchedulerM_'.
imapIO_ :: (Source r ix e, MonadUnliftIO m) => (ix -> e -> m a) -> Array r ix e -> m ()
imapIO_ action arr = withScheduler_ comp runOn
  where
    comp = getComp arr
    runOn scheduler = imapSchedulerM_ scheduler action arr
{-# INLINE imapIO_ #-}
-- | Schedule an index-aware action for every element on an existing
-- 'Scheduler', discarding each action's result.
--
-- The linear index range @[0, totalElem sz)@ is given to 'splitLinearlyWith_'
-- together with an element reader. Each linear index is converted back to a
-- multi-dimensional index with 'fromLinearIndex' before the action runs.
imapSchedulerM_ ::
     (Source r ix e, Monad m) => Scheduler m () -> (ix -> e -> m a) -> Array r ix e -> m ()
imapSchedulerM_ scheduler action arr =
  splitLinearlyWith_ scheduler (totalElem sz) (unsafeLinearIndex arr) runAt
  where
    sz = size arr
    runAt i e = void (action (fromLinearIndex sz i) e)
{-# INLINE imapSchedulerM_ #-}
-- | 'imapSchedulerM_' with the array placed before the action.
iforSchedulerM_ ::
     (Source r ix e, Monad m) => Scheduler m () -> Array r ix e -> (ix -> e -> m a) -> m ()
iforSchedulerM_ sched array f = imapSchedulerM_ sched f array
{-# INLINE iforSchedulerM_ #-}
-- | Index-aware action mapped over every element into a manifest array of
-- representation @r@.
--
-- The result is produced by 'generateArray' with the source's computation
-- strategy and size. With a parallel strategy, the actions may therefore run
-- concurrently; confirm in 'generateArray''s docs.
imapIO ::
     forall r ix b r' a m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
imapIO action arr = generateArray comp sz produce
  where
    comp = getComp arr
    sz = size arr
    produce ix = action ix (unsafeIndex arr ix)
{-# INLINE imapIO #-}
-- | 'mapIO' with its arguments flipped: the array comes first, then the action.
forIO ::
     forall r ix b r' a m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => Array r' ix a
  -> (a -> m b)
  -> m (Array r ix b)
forIO arr action = mapIO action arr
{-# INLINE forIO #-}
-- | Index-aware action mapped over every element, where each call also
-- receives a state value @s@ supplied by 'generateArrayWS' from the given
-- 'WorkerStates'. Presumably this is one state per worker; confirm in
-- 'generateArrayWS''s docs.
--
-- Only the states and the source size are passed on. The source array's own
-- computation strategy is not consulted here.
imapWS ::
     forall r ix b r' a s m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => WorkerStates s
  -> (ix -> a -> s -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
imapWS states f arr = generateArrayWS states (size arr) withState
  where
    withState ix s = f ix (unsafeIndex arr ix) s
{-# INLINE imapWS #-}
-- | 'imapWS' with the index ignored: the action sees only the element and the
-- worker state.
mapWS ::
     forall r ix b r' a s m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => WorkerStates s
  -> (a -> s -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
mapWS states f = imapWS states (const f)
{-# INLINE mapWS #-}
-- | 'imapWS' with the array placed before the index-aware action.
--
-- The parameter names now match their types: @arr@ is the array and @f@ is
-- the action. Previously they were swapped, which was misleading. Behaviour is
-- unchanged.
iforWS ::
     forall r ix b r' a s m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => WorkerStates s
  -> Array r' ix a
  -> (ix -> a -> s -> m b)
  -> m (Array r ix b)
iforWS states arr f = imapWS states f arr
{-# INLINE iforWS #-}
-- | 'mapWS' with the array placed before the action.
forWS ::
     forall r ix b r' a s m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => WorkerStates s
  -> Array r' ix a
  -> (a -> s -> m b)
  -> m (Array r ix b)
forWS states arr f = mapWS states f arr
{-# INLINE forWS #-}
-- | 'mapIO_' with its arguments flipped: the array comes first, then the
-- action.
forIO_ :: (Source r ix e, MonadUnliftIO m) => Array r ix e -> (e -> m a) -> m ()
forIO_ arr action = mapIO_ action arr
{-# INLINE forIO_ #-}
-- | 'imapIO' with its arguments flipped: the array comes first, then the
-- index-aware action.
iforIO ::
     forall r ix b r' a m. (Source r' ix a, Mutable r ix b, MonadUnliftIO m, PrimMonad m)
  => Array r' ix a
  -> (ix -> a -> m b)
  -> m (Array r ix b)
iforIO arr action = imapIO action arr
{-# INLINE iforIO #-}
-- | 'imapIO_' with its arguments flipped: the array comes first, then the
-- index-aware action.
iforIO_ :: (Source r ix a, MonadUnliftIO m) => Array r ix a -> (ix -> a -> m b) -> m ()
iforIO_ arr action = imapIO_ action arr
{-# INLINE iforIO_ #-}