{-# LANGUAGE TypeOperators, ExplicitForAll, FlexibleContexts #-}

module Data.Array.Repa.Operators.IndexSpace
	( reshape
	, append, (++)
	, transpose
        , extract
	, backpermute,         unsafeBackpermute
	, backpermuteDft,      unsafeBackpermuteDft
        , extend,              unsafeExtend 
        , slice,               unsafeSlice)
where
import Data.Array.Repa.Index
import Data.Array.Repa.Slice
import Data.Array.Repa.Base
import Data.Array.Repa.Repr.Delayed
import Data.Array.Repa.Operators.Traversal
import Data.Array.Repa.Shape		as S
import Prelude				hiding ((++))
import qualified Prelude		as P

-- Fully qualified name of this module, used as the prefix of the error
-- messages raised by the functions below.
stage :: String
stage = "Data.Array.Repa.Operators.IndexSpace"

-- Index space transformations ------------------------------------------------
-- | Impose a new shape on the elements of an array.
--
--   The new extent must contain exactly as many elements as the extent of
--   the source array (compared with 'S.size'); otherwise `error` is called.
--
--   Each index of the result is flattened to a linear index relative to the
--   new extent (`toIndex`), and that linear index is then expanded back into
--   an index of the source extent (`fromIndex`) to fetch the element.
reshape :: ( Shape sh1, Shape sh2
           , Source r1 e)
        => sh2                  -- ^ Extent of the result array.
        -> Array r1 sh1 e       -- ^ Source array.
        -> Array D  sh2 e

reshape sh2 arr
        | S.size sh2 /= S.size (extent arr)
        = error (stage P.++ ".reshape: reshaped array will not match size of the original")

        | otherwise
        = fromFunction sh2 lookupSrc

        where   -- Translate a result index into a source index by going
                -- through the shared linear index.
                lookupSrc ix
                        = arr `unsafeIndex` fromIndex (extent arr) (toIndex sh2 ix)
{-# INLINE [2] reshape #-}
 

-- | Append two arrays along their lowest dimension (the last component of
--   the index).
--
--   The lowest extent of the result is the sum of the lowest extents of the
--   two arguments. The remaining (outer) extent is the intersection of the
--   outer extents of both arguments, so every index of the result is within
--   range of whichever argument it is read from, even when the outer
--   extents of the arguments differ.
--
--   '(++)' is an operator synonym for 'append'.
append, (++)
        :: ( Shape sh
           , Source r1 e, Source r2 e)
        => Array r1 (sh :. Int) e       -- ^ Array providing the leading elements.
        -> Array r2 (sh :. Int) e       -- ^ Array providing the trailing elements.
        -> Array D  (sh :. Int) e

append arr1 arr2
 = unsafeTraverse2 arr1 arr2 fnExtent fnElem
 where
        -- Lowest extent of the first array: result positions below @n@ are
        -- read from arr1, all others from arr2 (shifted down by @n@).
        (_ :. n)        = extent arr1

        -- Only the outer extent common to both arguments is valid, because
        -- the element function below reads through unchecked accessors.
        fnExtent (sh1 :. i) (sh2 :. j)
                = intersectDim sh1 sh2 :. (i + j)

        fnElem f1 f2 (sh :. i)
                | i < n         = f1 (sh :. i)
                | otherwise     = f2 (sh :. (i - n))
{-# INLINE [2] append #-}


-- Operator form of 'append'; shares its type signature above.
(++) arr1 arr2 = append arr1 arr2
{-# INLINE (++) #-}


-- | Transpose the lowest two dimensions of an array.
--
--   Both the extent and every index have their last two components swapped;
--   any outer dimensions are left untouched.
--   Transposing an array twice yields the original.
transpose
        :: (Shape sh, Source r e)
        => Array  r (sh :. Int :. Int) e
        -> Array  D (sh :. Int :. Int) e

transpose arr
 = unsafeTraverse arr swapExtent swapIndex
 where
        -- Extent of the result: the last two extents exchanged.
        swapExtent (sh :. y :. x)
                = sh :. x :. y

        -- Read the source at the index with its last two components exchanged.
        swapIndex get (sh :. i :. j)
                = get (sh :. j :. i)
{-# INLINE [2] transpose #-}


-- | Extract a sub-range of elements from an array.
--
--   Element @ix@ of the result is the element of the source at
--   @addDim start ix@. The source is read with `unsafeIndex`, and this
--   function itself does not check that the requested range lies within
--   the source extent.
extract :: (Shape sh, Source r e)
        => sh                   -- ^ Starting index.
        -> sh                   -- ^ Size of result.
        -> Array r sh e         -- ^ Source array.
        -> Array D sh e
extract start sz arr
        = fromFunction sz lookupSrc
        where   -- Offset the result index by the starting index.
                lookupSrc ix = unsafeIndex arr (addDim start ix)
{-# INLINE [2] extract #-}


-- | Backwards permutation of an array's elements.
--
--   The result has the given extent, and its element at index @ix@ is the
--   element of the source array at index @perm ix@.
--
--   'backpermute' is built on `traverse` and 'unsafeBackpermute' on
--   `unsafeTraverse`; they are otherwise identical.
backpermute, unsafeBackpermute
        :: forall r sh1 sh2 e
        .  ( Shape sh1, Shape sh2
           , Source r e)
        => sh2                  -- ^ Extent of result array.
        -> (sh2 -> sh1)         -- ^ Function mapping each index in the result array
                                --      to an index of the source array.
        -> Array r  sh1 e       -- ^ Source array.
        -> Array D  sh2 e

backpermute newExtent perm arr
        = traverse arr
                (\_   -> newExtent)     -- the source extent is ignored
                (\get -> get . perm)    -- look up through the permutation
{-# INLINE [2] backpermute #-}

unsafeBackpermute newExtent perm arr
        = unsafeTraverse arr
                (\_   -> newExtent)     -- the source extent is ignored
                (\get -> get . perm)    -- look up through the permutation
{-# INLINE [2] unsafeBackpermute #-}


-- | Default backwards permutation of an array's elements.
--
--   The result has the same extent as the default array (@arrDft@).
--   For each result index the mapping function is consulted: when it
--   returns @Just ix'@ the element is read from the source array at @ix'@;
--   when it returns `Nothing` the element is taken from the default array
--   at the same index.
--
--   'backpermuteDft' reads both arrays with `index`, whereas
--   'unsafeBackpermuteDft' reads them with `unsafeIndex`.
backpermuteDft, unsafeBackpermuteDft
        :: forall r1 r2 sh1 sh2 e
        .  ( Shape sh1,   Shape sh2
           , Source r1 e, Source r2 e)
        => Array r2 sh2 e       -- ^ Default values (@arrDft@)
        -> (sh2 -> Maybe sh1)   -- ^ Function mapping each index in the result array
                                --      to an index in the source array.
        -> Array r1 sh1 e       -- ^ Source array.
        -> Array D  sh2 e

backpermuteDft arrDft fnIndex arrSrc
        = fromFunction (extent arrDft) pick
        where   -- Fall back to the default array when no source index is given.
                pick ix = maybe (arrDft `index` ix)
                                (arrSrc `index`)
                                (fnIndex ix)
{-# INLINE [2] backpermuteDft #-}

unsafeBackpermuteDft arrDft fnIndex arrSrc
        = fromFunction (extent arrDft) pick
        where   -- Fall back to the default array when no source index is given.
                pick ix = maybe (arrDft `unsafeIndex` ix)
                                (arrSrc `unsafeIndex`)
                                (fnIndex ix)
{-# INLINE [2] unsafeBackpermuteDft #-}



-- | Extend an array, according to a given slice specification.
--
--   Note the argument order: the slice specification comes first and the
--   array second.
--
--   For example, to replicate the rows of an array use the following:
--
--   @extend (Any :. (5::Int) :. All) arr@
--
--   The extent of the result is @fullOfSlice sl (extent arr)@, and each
--   index of the result is mapped back to the source with @sliceOfFull sl@.
--
--   'extend' is built on 'backpermute' and 'unsafeExtend' on
--   'unsafeBackpermute'; they are otherwise identical.
extend, unsafeExtend
        :: ( Slice sl
           , Shape (SliceShape sl)
           , Shape (FullShape sl)
           , Source r e)
        => sl                           -- ^ Slice specification.
        -> Array r (SliceShape sl) e    -- ^ Source array.
        -> Array D (FullShape sl)  e

extend sl arr
        = backpermute fullExtent (sliceOfFull sl) arr
        where   fullExtent = fullOfSlice sl (extent arr)
{-# INLINE [2] extend #-}

unsafeExtend sl arr
        = unsafeBackpermute fullExtent (sliceOfFull sl) arr
        where   fullExtent = fullOfSlice sl (extent arr)
{-# INLINE [2] unsafeExtend #-}



-- | Take a slice from an array, according to a given specification.
--
--   For example, to take a row from a matrix use the following:
--
--   @slice arr (Any :. (5::Int) :. All)@
--
--   To take a column use:
--
--   @slice arr (Any :. (5::Int))@
--
--   The extent of the result is @sliceOfFull sl (extent arr)@, and each
--   index of the result is mapped back to the source with @fullOfSlice sl@.
--
--   'slice' is built on 'backpermute' and 'unsafeSlice' on
--   'unsafeBackpermute'; they are otherwise identical.
slice, unsafeSlice
        :: ( Slice sl
           , Shape (FullShape sl)
           , Shape (SliceShape sl)
           , Source r e)
        => Array r (FullShape sl) e     -- ^ Source array.
        -> sl                           -- ^ Slice specification.
        -> Array D (SliceShape sl) e

slice arr sl
        = backpermute sliceExtent (fullOfSlice sl) arr
        where   sliceExtent = sliceOfFull sl (extent arr)
{-# INLINE [2] slice #-}


unsafeSlice arr sl
        = unsafeBackpermute sliceExtent (fullOfSlice sl) arr
        where   sliceExtent = sliceOfFull sl (extent arr)
{-# INLINE [2] unsafeSlice #-}