{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE NoMonomorphismRestriction, PatternGuards #-}

module Data.Array.Repa.Operators.Mapping
	( map
	, zipWith
	, (+^)
	, (-^)
	, (*^)
	, (/^))
where
import Data.Array.Repa.Internals.Elt
import Data.Array.Repa.Internals.Base
import Data.Array.Repa.Shape		as S
import qualified Data.Vector.Unboxed	as V
import qualified Prelude		as P
import Prelude				(($), (.), (+), (*), (+), (/), (-))

-- | Apply a worker function to each element of an array, yielding a new array with the same extent.
--
--   This is specialised for arrays of up to four regions, using more breaks fusion.
--
map	:: (Shape sh, Elt a, Elt b)
	=> (a -> b)
	-> Array sh a
	-> Array sh b

{-# INLINE map #-}
map f (Array sh regions)
 = Array sh (mapRegions regions)

 where	{-# INLINE mapRegions #-}
	mapRegions rs
	 = case rs of
		[]		 -> []
		[r]		 -> [mapRegion r]
		[r1, r2] 	 -> [mapRegion r1, mapRegion r2]
		[r1, r2, r3]	 -> [mapRegion r1, mapRegion r2, mapRegion r3]
		[r1, r2, r3, r4] -> [mapRegion r1, mapRegion r2, mapRegion r3, mapRegion r4]
		_		 -> mapRegions' rs

	mapRegions' rs
	 = case rs of
		[]		 -> []
		(r : rs')	 -> mapRegion r : mapRegions' rs'

	{-# INLINE mapRegion #-}
	mapRegion (Region range gen)
	 = Region range (mapGen gen)

	{-# INLINE mapGen #-}
	mapGen gen
	 = case gen of
		GenManifest vec
		 -> GenCursor
			P.id
			addDim
		 	(\ix -> f $ V.unsafeIndex vec $ S.toIndex sh ix)

		GenCursor makeCursor shiftCursor loadElem
		 -> GenCursor makeCursor shiftCursor (f . loadElem)


-- | Combine two arrays, element-wise, with a binary operator.
--	If the extent of the two array arguments differ,
--	then the resulting array's extent is their intersection.
--
zipWith :: (Shape sh, Elt a, Elt b, Elt c)
	=> (a -> b -> c)
	-> Array sh a
	-> Array sh b
	-> Array sh c

{-# INLINE zipWith #-}
zipWith f arr1 arr2
 	| Array sh2 [_] <- arr1
	, Array sh1 [ Region g21 (GenCursor make21 _ load21)
		    , Region g22 (GenCursor make22 _ load22)] <- arr2

	= let	{-# INLINE load21' #-}
		load21' ix	= f (arr1 `unsafeIndex` ix) (load21 $ make21 ix)

		{-# INLINE load22' #-}
		load22' ix	= f (arr1 `unsafeIndex` ix) (load22 $ make22 ix)

	  in	Array (S.intersectDim sh1 sh2)
		      [ Region g21 (GenCursor P.id addDim load21')
		      , Region g22 (GenCursor P.id addDim load22') ]

	| P.otherwise
	= let	{-# INLINE getElem' #-}
		getElem' ix	= f (arr1 `unsafeIndex` ix) (arr2 `unsafeIndex` ix)
	  in	fromFunction
			(S.intersectDim (extent arr1) (extent arr2))
			getElem'


{-# INLINE (+^) #-}
(+^)	= zipWith (+)

{-# INLINE (-^) #-}
(-^)	= zipWith (-)

{-# INLINE (*^) #-}
(*^)	= zipWith (*)

{-# INLINE (/^) #-}
(/^)	= zipWith (/)