module Data.Array.Repa.Repr.Partitioned
        ( P, Array (..)
        , Range(..)
        , inRange)
where
import Data.Array.Repa.Base
import Data.Array.Repa.Shape
import Data.Array.Repa.Eval
import Data.Array.Repa.Repr.Delayed


-- | Partitioned arrays.
--   The last partition takes priority
--
--   An array in this representation ('APart') stores the extent of the whole
--   array, a 'Range', and two component arrays. Indices accepted by the
--   range's predicate are read from the first component; every other index is
--   read from the second.
--
--   NOTE(review): "the last partition takes priority" matches loading. There
--   the in-range part of the first component is written before the second
--   component is loaded into the same target, so if the second component also
--   writes inside the range, its writes land last. 'index', however, consults
--   the range first and prefers the first component. Confirm which behaviour
--   the note above is meant to describe.
--
--   These are produced by Repa's support functions and allow arrays to be defined
--   using a different element function for each partition.
--
--   The basic idea is described in ``Efficient Parallel Stencil Convolution'',
--   Ben Lippmeier and Gabriele Keller, Haskell 2011 -- though the underlying
--   array representation has changed since this paper was published.
--
data P r1 r2

-- | A region of an array's index space.
--
--   The two indices are what the partitioned loaders hand on to the
--   'LoadRange' methods (callers in this module bind them as @ix@ and @sz@;
--   presumably a starting index and a size -- confirm). The predicate is what
--   'inRange' consults when indexing.
data Range sh
        = Range
                !sh             -- first index describing the region
                !sh             -- second index describing the region
                (sh -> Bool)    -- predicate: is this index inside the region?

-- | Check whether an index is within the given range.
--
--   Only the predicate stored in the 'Range' is consulted; the two indices
--   the range carries are ignored here.
inRange :: Range sh -> sh -> Bool
inRange (Range _ _ isInside) ix
        = isInside ix
{-# INLINE inRange #-}


-- Repr -----------------------------------------------------------------------
-- | Read elements from a partitioned array.
instance (Source r1 e, Source r2 e) => Source (P r1 r2) e where
 data Array (P r1 r2) sh e
        = APart !sh                          -- size of the whole array
                !(Range sh) !(Array r1 sh e) -- if in range use this array
                !(Array r2 sh e)             -- otherwise use this array

 -- Indices accepted by the range predicate come from the first array;
 -- every other index falls through to the second.
 index (APart _ range arr1 arr2) ix
        | inRange range ix = index arr1 ix
        | otherwise        = index arr2 ix
 {-# INLINE index #-}

 -- Turn the flat offset back into a shape index using the whole-array
 -- extent, then go through 'index' so the range check still applies.
 linearIndex arr@(APart sh _ _ _) ix
        = index arr (fromIndex sh ix)
 {-# INLINE linearIndex #-}

 extent (APart sh _ _ _)
        = sh
 {-# INLINE extent #-}

 -- Force the extent, the range and both component arrays, then yield y.
 -- The chain is nested explicitly with '$' so its meaning does not depend
 -- on the fixities of the backtick operators. No fixity declaration for
 -- deepSeqRange is visible here, so it would default to infixl 9.
 deepSeqArray (APart sh range arr1 arr2) y
        = deepSeq sh
        $ deepSeqRange range
        $ deepSeqArray arr1
        $ deepSeqArray arr2 y
 {-# INLINE deepSeqArray #-}


-- | Force the components of a 'Range', then return the second argument.
--
--   Both indices are forced with 'deepSeq' (hence the 'Shape' constraint);
--   the predicate is forced with 'seq'.
deepSeqRange :: Shape sh => Range sh -> b -> b
deepSeqRange (Range ix sz isInside) y
        = ix `deepSeq` (sz `deepSeq` (isInside `seq` y))
{-# INLINE deepSeqRange #-}


-- Load -----------------------------------------------------------------------
-- Both loaders write the first component's range into the target, then load
-- the second component into the same target.
-- NOTE(review): if the second component also writes indices inside the
-- range, those writes happen last and win; confirm callers rely on this.
instance (LoadRange r1 sh e, Load r2 sh e)
        => Load (P r1 r2) sh e where
 loadP (APart _ (Range ix sz _) arr1 arr2) marr
  = do  loadRangeP arr1 marr ix sz
        loadP arr2 marr
 {-# INLINE loadP #-}

 loadS (APart _ (Range ix sz _) arr1 arr2) marr
  = do  loadRangeS arr1 marr ix sz
        loadS arr2 marr
 {-# INLINE loadS #-}