{-# LANGUAGE TypeOperators, ExplicitForAll, FlexibleContexts #-}

module Data.Array.Repa.Operators.IndexSpace
        ( reshape
        , append, (++)
        , transpose
        , extract
        , backpermute,         unsafeBackpermute
        , backpermuteDft,      unsafeBackpermuteDft
        , extend,              unsafeExtend 
        , slice,               unsafeSlice)
where
import Data.Array.Repa.Index
import Data.Array.Repa.Slice
import Data.Array.Repa.Base
import Data.Array.Repa.Repr.Delayed
import Data.Array.Repa.Operators.Traversal
import Data.Array.Repa.Shape            as S
import Prelude                          hiding ((++), traverse)
import qualified Prelude                as P 

-- | Module name, used as a prefix for error messages raised in this module.
stage :: String
stage = "Data.Array.Repa.Operators.IndexSpace"

-- Index space transformations ------------------------------------------------
-- | Impose a new shape on the elements of an array.
--   The new extent must be the same size as the original, else `error`.
--
--   Each result index is converted to a linear index with 'toIndex' and
--   back to a source index with 'fromIndex', so elements keep their
--   linear order.
reshape :: ( Shape sh1, Shape sh2
           , Source r1 e)
        => sh2
        -> Array r1 sh1 e
        -> Array D  sh2 e
reshape sh2 arr
        | S.size sh2 /= S.size (extent arr)
        = error
        $ stage P.++ ".reshape: reshaped array will not match size of the original"

        -- Sizes agree, so every linear index produced by 'toIndex sh2' is a
        -- valid position in the source and 'unsafeIndex' is safe here.
        | otherwise
        = fromFunction sh2
        $ unsafeIndex arr . fromIndex (extent arr) . toIndex sh2
{-# INLINE [2] reshape #-}
 

-- | Append two arrays along their innermost dimension.
--
--   The innermost extent of the result is the sum of the two innermost
--   extents; the outer extent is the intersection ('intersectDim') of the
--   two outer extents. @(++)@ is a synonym for 'append'.
append, (++)
        :: ( Shape sh
           , Source r1 e, Source r2 e)
        => Array r1 (sh :. Int) e
        -> Array r2 (sh :. Int) e
        -> Array D  (sh :. Int) e

append arr1 arr2
 = unsafeTraverse2 arr1 arr2 fnExtent fnElem
 where
        -- Innermost extent of the first array: inner indices below it are
        -- read from arr1, the rest are shifted down by it and read from arr2.
        (_ :. n)        = extent arr1

        fnExtent (sh1 :. i) (sh2 :. j)
                = intersectDim sh1 sh2 :. (i + j)

        fnElem f1 f2 (sh :. i)
                | i < n         = f1 (sh :. i)
                | otherwise     = f2 (sh :. (i - n))
{-# INLINE [2] append #-}


-- Operator synonym for 'append'; its type comes from the joint
-- @append, (++)@ signature.
(++) arr1 arr2 = append arr1 arr2
{-# INLINE (++) #-}


-- | Transpose the lowest two dimensions of an array.
--      Transposing an array twice yields the original.
--
--   An extent @sh :. m :. n@ becomes @sh :. n :. m@, and result element
--   @sh :. i :. j@ is source element @sh :. j :. i@.
transpose
        :: (Shape sh, Source r e)
        => Array  r (sh :. Int :. Int) e
        -> Array  D (sh :. Int :. Int) e

transpose arr
 = unsafeTraverse arr
        (\(sh :. m :. n)        -> sh :. n :. m)
        (\f -> \(sh :. i :. j)  -> f (sh :. j :. i))
{-# INLINE [2] transpose #-}


-- | Extract a sub-range of elements from an array.
--
--   Result element @ix@ is source element @addDim start ix@. The source is
--   read with 'unsafeIndex' and no range check is performed, so the caller
--   must ensure the requested range lies within the source's extent.
extract :: (Shape sh, Source r e)
        => sh                   -- ^ Starting index.
        -> sh                   -- ^ Size of result.
        -> Array r sh e         -- ^ Source array.
        -> Array D sh e
extract start sz arr
        = fromFunction sz (\ix -> arr `unsafeIndex` (addDim start ix))
{-# INLINE [2] extract #-}


-- | Backwards permutation of an array's elements.
--
--   Result element @ix@ is source element @perm ix@. 'backpermute' is built
--   on 'traverse'; 'unsafeBackpermute' is the same but built on
--   'unsafeTraverse'.
backpermute, unsafeBackpermute
        :: forall r sh1 sh2 e
        .  ( Shape sh1
           , Source r e)
        => sh2                  -- ^ Extent of result array.
        -> (sh2 -> sh1)         -- ^ Function mapping each index in the result array
                                --      to an index of the source array.
        -> Array r  sh1 e       -- ^ Source array.
        -> Array D  sh2 e

-- 'traverse' is Repa's array traversal; Prelude's is hidden by the imports.
backpermute newExtent perm arr
        = traverse arr (const newExtent) (. perm)
{-# INLINE [2] backpermute #-}

-- Same as 'backpermute', but built on 'unsafeTraverse'.
-- Its type comes from the joint @backpermute, unsafeBackpermute@ signature.
unsafeBackpermute newExtent perm arr
        = unsafeTraverse arr (const newExtent) (. perm)
{-# INLINE [2] unsafeBackpermute #-}


-- | Default backwards permutation of an array's elements.
--      If the function returns `Nothing` then the value at that index is taken
--      from the default array (@arrDft@).
--
--   The result has the extent of @arrDft@. 'backpermuteDft' reads both
--   arrays with 'index'; 'unsafeBackpermuteDft' uses 'unsafeIndex'.
backpermuteDft, unsafeBackpermuteDft
        :: forall r1 r2 sh1 sh2 e
        .  ( Shape sh1,   Shape sh2
           , Source r1 e, Source r2 e)
        => Array r2 sh2 e       -- ^ Default values (@arrDft@)
        -> (sh2 -> Maybe sh1)   -- ^ Function mapping each index in the result array
                                --      to an index in the source array.
        -> Array r1 sh1 e       -- ^ Source array.
        -> Array D  sh2 e

backpermuteDft arrDft fnIndex arrSrc
        = fromFunction (extent arrDft) fnElem
        where   fnElem ix
                 = case fnIndex ix of
                        Just ix'        -> arrSrc `index` ix'
                        Nothing         -> arrDft `index` ix
{-# INLINE [2] backpermuteDft #-}

-- Same as 'backpermuteDft', but reads both arrays with 'unsafeIndex'.
-- Its type comes from the joint @backpermuteDft, unsafeBackpermuteDft@
-- signature.
unsafeBackpermuteDft arrDft fnIndex arrSrc
        = fromFunction (extent arrDft) fnElem
        where   fnElem ix
                 = case fnIndex ix of
                        Just ix'        -> arrSrc `unsafeIndex` ix'
                        Nothing         -> arrDft `unsafeIndex` ix
{-# INLINE [2] unsafeBackpermuteDft #-}



-- | Extend an array, according to a given slice specification.
--
--   For example, to replicate the rows of an array use the following:
--
--   @extend (Any :. (5::Int) :. All) arr@
--
--   The result extent is @fullOfSlice sl (extent arr)@, and each result
--   index is mapped back to the source with @sliceOfFull sl@. 'extend' is
--   built on 'backpermute'; 'unsafeExtend' on 'unsafeBackpermute'.
extend, unsafeExtend
        :: ( Slice sl
           , Shape (SliceShape sl)
           , Source r e)
        => sl
        -> Array r (SliceShape sl) e
        -> Array D (FullShape sl)  e

extend sl arr
        = backpermute
                (fullOfSlice sl (extent arr))
                (sliceOfFull sl)
                arr
{-# INLINE [2] extend #-}

-- Same as 'extend', but built on 'unsafeBackpermute'.
-- Its type comes from the joint @extend, unsafeExtend@ signature.
unsafeExtend sl arr
        = unsafeBackpermute
                (fullOfSlice sl (extent arr))
                (sliceOfFull sl)
                arr
{-# INLINE [2] unsafeExtend #-}



-- | Take a slice from an array, according to a given specification.
--
--   For example, to take a row from a matrix use the following:
--
--   @slice arr (Any :. (5::Int) :. All)@
--
--   To take a column use:
--
--   @slice arr (Any :. (5::Int))@
--
--   The result extent is @sliceOfFull sl (extent arr)@, and each result
--   index is mapped back to the source with @fullOfSlice sl@. 'slice' is
--   built on 'backpermute'; 'unsafeSlice' on 'unsafeBackpermute'.
slice, unsafeSlice
        :: ( Slice sl
           , Shape (FullShape sl)
           , Source r e)
        => Array r (FullShape sl) e
        -> sl
        -> Array D (SliceShape sl) e

slice arr sl
        = backpermute
                (sliceOfFull sl (extent arr))
                (fullOfSlice sl)
                arr
{-# INLINE [2] slice #-}


-- Same as 'slice', but built on 'unsafeBackpermute'.
-- Its type comes from the joint @slice, unsafeSlice@ signature.
unsafeSlice arr sl
        = unsafeBackpermute
                (sliceOfFull sl (extent arr))
                (fullOfSlice sl)
                arr
{-# INLINE [2] unsafeSlice #-}