module Data.Array.Accelerate.Array.Representation (
  
  Shape(..), Slice(..), SliceIndex(..),
) where
import Data.Array.Accelerate.Type
#include "accelerate.h"
-- | Array shapes, and indices into them, are nested tuples of 'Int'
-- terminated by @()@; every @(sh, Int)@ layer adds one dimension.
class (Eq sh, Slice sh) => Shape sh where
  -- | 'dim' is the number of dimensions (the rank); 'size' is the number
  -- of elements, i.e. the product of all extents.
  dim, size :: sh -> Int

  -- | Component-wise minimum of two shapes.
  intersect :: sh -> sh -> sh

  -- | Distinguished sentinel index.  NOTE(review): presumably used by
  -- callers to mark elements that should be skipped; confirm at use sites.
  ignore :: sh

  -- | @index shape ix@ is the row-major linear offset of @ix@ inside
  -- @shape@, with the innermost component varying fastest.
  index :: sh -> sh -> Int

  -- | @bound shape ix b@ resolves an index that may lie outside @shape@
  -- according to the boundary condition @b@, yielding either an in-range
  -- index ('Right') or, for a constant boundary, that constant ('Left').
  bound :: sh -> sh -> Boundary e -> Either e sh

  -- | @iter shape f c r@ applies @f@ to every index of @shape@ and merges
  -- the results with @c@, using @r@ as the seed.
  iter :: sh -> (sh -> a) -> (a -> a -> a) -> a -> a

  -- | Seedless variant of 'iter'; the iteration space must be non-empty.
  iter1 :: sh -> (sh -> a) -> (a -> a -> a) -> a

  -- | Build a shape from inclusive (lower, upper) index bounds.
  rangeToShape :: (sh, sh) -> sh

  -- | Zero-based inclusive (lower, upper) bounds of a shape.
  shapeToRange :: sh -> (sh, sh)

  -- | Extents of a shape as a list, innermost extent first.
  shapeToList :: sh -> [Int]

  -- | Inverse of 'shapeToList'.
  listToShape :: [Int] -> sh
-- | The rank-0 shape: no dimensions and exactly one element.
instance Shape () where
  dim  () = 0
  size () = 1

  intersect () () = ()
  ignore          = ()
  index () ()     = 0
  bound () () _   = Right ()

  -- The only index is (); it is merged with the seed, seed on the left.
  iter  () f c e = c e (f ())
  iter1 () f _   = f ()

  rangeToShape ((), ()) = ()
  shapeToRange ()       = ((), ())

  shapeToList () = []

  -- Only the empty list describes a rank-0 shape.
  listToShape [] = ()
  listToShape _  = INTERNAL_ERROR(error) "listToShape" "non-empty list when converting to unit"
-- | A shape of rank n+1 pairs a rank-n shape with the extent of one extra
-- dimension.  That extra 'Int' component is the innermost dimension: it
-- varies fastest in the linear order used by 'index' and 'iter', and it is
-- the head of the list produced by 'shapeToList'.
instance Shape sh => Shape (sh, Int) where
  dim (sh, _)                       = dim sh + 1
  size (sh, sz)                     = size sh * sz

  -- Component-wise minimum of the extents.
  (sh1, sz1) `intersect` (sh2, sz2) = (sh1 `intersect` sh2, sz1 `min` sz2)

  -- Sentinel index.  -1 is never a valid position in any dimension, so the
  -- sentinel cannot collide with a real element (a non-negative value such
  -- as 1 would).
  ignore                            = (ignore, -1)

  -- Row-major linearisation: offset of the outer index times this extent,
  -- plus the position within this dimension.  BOUNDS_CHECK comes from
  -- accelerate.h; presumably it verifies 0 <= i < sz -- confirm there.
  index (sh, sz) (ix, i)            = BOUNDS_CHECK(checkIndex) "index" i sz
                                    $ index sh ix * sz + i

  -- Bring a possibly out-of-range index back into the shape, one dimension
  -- at a time, according to the boundary condition:
  --
  --   * Clamp    - snap to the nearest edge (0 or sz-1);
  --   * Mirror   - reflect about the edge without repeating the edge
  --                element (-1 maps to 1, sz maps to sz-2);
  --   * Wrap     - continue from the opposite edge (-1 maps to sz-1,
  --                sz maps to 0);
  --   * Constant - abandon the lookup and return the constant as Left.
  --
  -- NOTE(review): only a single reflection or wrap is applied, so an index
  -- more than one extent outside the shape is not brought back in range.
  bound (sh, sz) (ix, i) bndy
    | i < 0                         = case bndy of
                                        Clamp      -> bound sh ix bndy `addDim` 0
                                        Mirror     -> bound sh ix bndy `addDim` (-i)
                                        Wrap       -> bound sh ix bndy `addDim` (sz+i)
                                        Constant e -> Left e
    | i >= sz                       = case bndy of
                                        Clamp      -> bound sh ix bndy `addDim` (sz-1)
                                        Mirror     -> bound sh ix bndy `addDim` (sz-(i-sz+2))
                                        Wrap       -> bound sh ix bndy `addDim` (i-sz)
                                        Constant e -> Left e
    | otherwise                     = bound sh ix bndy `addDim` i
    where
      -- Extend a resolved outer index with this dimension's component;
      -- a Left produced by an outer dimension short-circuits.
      Right ds `addDim` d = Right (ds, d)
      Left e   `addDim` _ = Left e

  -- The recursive call traverses the outer dimensions; for each outer index
  -- the inner positions 0 .. sz-1 are combined right-associatively, ending
  -- in the seed r.  The seed is therefore used once per inner row, so it
  -- should presumably be an identity for c.
  iter (sh, sz) f c r = iter sh (\ix -> iter' (ix,0)) c r
    where
      iter' (ix,i) | i >= sz   = r
                   | otherwise = f (ix,i) `c` iter' (ix,i+1)

  -- Seedless fold, so the iteration space must be non-empty.  Non-positive
  -- extents are rejected up front: with a negative extent the counter would
  -- never reach the last position sz-1 and the recursion would not end.
  iter1 (_,  sz) _ _
    | sz <= 0        = BOUNDS_ERROR(error) "iter1" "empty iteration space"
  iter1 (sh, sz) f c = iter1 sh (\ix -> iter1' (ix,0)) c
    where
      iter1' (ix,i) | i == sz-1 = f (ix,i)
                    | otherwise = f (ix,i) `c` iter1' (ix,i+1)

  -- Inclusive bounds low .. high span high - low + 1 positions.
  rangeToShape ((sh1, sz1), (sh2, sz2))
    = (rangeToShape (sh1, sh2), sz2 - sz1 + 1)

  -- Zero-based inclusive bounds: positions run from 0 to sz-1.
  shapeToRange (sh, sz)
    = let (low, high) = shapeToRange sh
      in  ((low, 0), (high, sz - 1))

  -- The innermost extent is the head of the list.
  shapeToList (sh,sz) = sz : shapeToList sh
  listToShape []      = INTERNAL_ERROR(error) "listToShape" "empty list when converting to Ix"
  listToShape (x:xs)  = (listToShape xs,x)
-- | Slice specifications: nested tuples in which @()@ keeps a whole
-- dimension and 'Int' fixes a dimension at one position.
class Slice sl where
  -- | Shape of what remains after applying the specification.
  type SliceShape sl
  -- | Shape formed by the dimensions the specification fixes.
  type CoSliceShape sl
  -- | Shape of the array the specification is applied to.
  type FullShape sl

  -- | Reify the specification as a 'SliceIndex' witness.  The instances in
  -- this module never inspect the argument; it only selects the type.
  sliceIndex :: sl -> SliceIndex sl (SliceShape sl) (CoSliceShape sl) (FullShape sl)
-- | The empty specification, applied to a rank-0 array.
instance Slice () where
  type FullShape    () = ()
  type SliceShape   () = ()
  type CoSliceShape () = ()
  sliceIndex = const SliceNil
-- | Keep the whole innermost dimension: it survives into the slice shape
-- and contributes nothing to the co-slice.
instance Slice sl => Slice (sl, ()) where
  type FullShape    (sl, ()) = (FullShape sl, Int)
  type SliceShape   (sl, ()) = (SliceShape sl, Int)
  type CoSliceShape (sl, ()) = CoSliceShape sl
  -- The witness depends only on the type, so the argument is ignored and
  -- the inner specification is a type-only placeholder.
  sliceIndex = const (SliceAll inner)
    where
      inner = sliceIndex (undefined :: sl)
-- | Fix the innermost dimension: it disappears from the slice shape and is
-- recorded in the co-slice instead.
instance Slice sl => Slice (sl, Int) where
  type FullShape    (sl, Int) = (FullShape sl, Int)
  type SliceShape   (sl, Int) = SliceShape sl
  type CoSliceShape (sl, Int) = (CoSliceShape sl, Int)
  -- The witness depends only on the type, so the argument is ignored and
  -- the inner specification is a type-only placeholder.
  sliceIndex = const (SliceFixed inner)
    where
      inner = sliceIndex (undefined :: sl)
-- | Witness relating a slice specification @ix@ to the shapes it induces:
-- the remaining @slice@ shape, the fixed @coSlice@ shape, and the full
-- shape @sliceDim@ that is being sliced.
data SliceIndex ix slice coSlice sliceDim where
  -- | Nothing left to slice.
  SliceNil
    :: SliceIndex () () () ()
  -- | Keep the whole innermost dimension.
  SliceAll
    :: SliceIndex ix slice co dim
    -> SliceIndex (ix, ()) (slice, Int) co (dim, Int)
  -- | Fix the innermost dimension at one position.
  SliceFixed
    :: SliceIndex ix slice co dim
    -> SliceIndex (ix, Int) slice (co, Int) (dim, Int)
-- | Render a witness as its constructor chain, each nested witness wrapped
-- in parentheses.
instance Show (SliceIndex ix slice coSlice sliceDim) where
  show sl =
    case sl of
      SliceNil        -> "SliceNil"
      SliceAll rest   -> wrap "SliceAll" rest
      SliceFixed rest -> wrap "SliceFixed" rest
    where
      wrap :: Show a => String -> a -> String
      wrap name inner = name ++ " (" ++ show inner ++ ")"