{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE DeriveAnyClass       #-}
{-# LANGUAGE DeriveGeneric        #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeApplications     #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Sugar.Shape
  where
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Representation.Type
import qualified Data.Array.Accelerate.Representation.Shape         as R
import qualified Data.Array.Accelerate.Representation.Slice         as R
import Data.Kind
import GHC.Generics
-- | Shape with no dimensions: just 'Z'.
type DIM0 = Z
-- | Shape with 1 'Int' dimension.
type DIM1 = Z :. Int
-- | Shape with 2 'Int' dimensions.
type DIM2 = Z :. Int :. Int
-- | Shape with 3 'Int' dimensions.
type DIM3 = Z :. Int :. Int :. Int
-- | Shape with 4 'Int' dimensions.
type DIM4 = Z :. Int :. Int :. Int :. Int
-- | Shape with 5 'Int' dimensions.
type DIM5 = Z :. Int :. Int :. Int :. Int :. Int
-- | Shape with 6 'Int' dimensions.
type DIM6 = Z :. Int :. Int :. Int :. Int :. Int :. Int
-- | Shape with 7 'Int' dimensions.
type DIM7 = Z :. Int :. Int :. Int :. Int :. Int :. Int :. Int
-- | Shape with 8 'Int' dimensions.
type DIM8 = Z :. Int :. Int :. Int :. Int :. Int :. Int :. Int :. Int
-- | Shape with 9 'Int' dimensions.
type DIM9 = Z :. Int :. Int :. Int :. Int :. Int :. Int :. Int :. Int :. Int
-- | The empty shape/index. It is the base case from which every shape is
-- built by snocing dimensions on with ':.'.
data Z = Z
  deriving (Eq, Show, Generic, Elt)
-- | Snoc a component @h@ onto the right-hand end of @t@. It is used both for
-- shapes (@Z :. 3 :. 4@) and for slice specifiers (@sl :. All@, @sl :. Int@).
-- Both fields are strict.
data t :. h = !t :. !h
  deriving (Generic, Eq)

-- Left-associative at precedence 3: @Z :. 2 :. 3@ means @(Z :. 2) :. 3@.
infixl 3 :.
-- | This instance is hand-written rather than derived. A derived instance
-- ignores associativity and would parenthesise every nested ':.', printing
-- @(Z :. 2) :. 3@.
--
-- Here the incoming precedence @d@ is passed unchanged to both components,
-- and 'showParen' is never used, so shapes print flat as @Z :. 2 :. 3@. The
-- flip side is that no parentheses are added at any precedence. Callers that
-- embed a shape in a larger expression must add them themselves.
instance (Show sh, Show sz) => Show (sh :. sz) where
  showsPrec d (tl :. hd) =
    showsPrec d tl . (" :. " ++) . showsPrec d hd
-- | Slice specifier for a dimension that is kept whole. In @sl :. All@ the
-- dimension adds an 'Int' to both the 'SliceShape' and the 'FullShape'. It
-- adds nothing to the 'CoSliceShape' (see the @Slice (sl :. All)@ instance).
data All = All
  deriving (Eq, Show, Generic, Elt)
-- | Slice specifier covering every dimension of the shape @sh@ at once.
-- @Any sh@ has 'SliceShape' and 'FullShape' equal to @sh@, and an empty
-- ('Z') 'CoSliceShape'. The type parameter is phantom: the single
-- constructor carries no data.
data Any sh = Any
  deriving (Eq, Show, Generic)
-- | Division specifier. In the corresponding 'DivisionSlice', a 'Split'
-- position becomes an 'Int' (a fixed index), whereas an 'All' position
-- stays 'All'.
data Split = Split
  deriving (Eq, Show)
-- | Division specifier for a whole shape @sh@. Its 'DivisionSlice' is @sh@
-- itself, so every dimension becomes a fixed index (the instance uses
-- 'sliceNoneIndex'). The type parameter is phantom.
data Divide sh = Divide
  deriving (Eq, Show)
-- | The rank of the shape type @sh@, as computed by 'R.rank' on its
-- reified representation (presumably the number of dimensions). There is no
-- value argument, so @sh@ must be fixed with a type application, e.g.
-- @rank \@DIM2@.
rank :: forall sh. Shape sh => Int
rank = R.rank shr
  where
    shr = shapeR @sh
-- | Delegates to 'R.size' on the representation of the given shape.
-- This is presumably the total number of elements the shape describes;
-- confirm in Representation.Shape.
size :: forall sh. Shape sh => sh -> Int
size extent = R.size (shapeR @sh) (fromElt extent)
-- | The shape value that 'R.empty' yields for @sh@, converted back from its
-- representation.
empty :: forall sh. Shape sh => sh
empty = toElt (R.empty shr)
  where
    shr = shapeR @sh
-- | Combine two shapes with 'R.intersect' on their representations.
-- NOTE(review): presumably a per-dimension minimum; confirm in
-- Representation.Shape.
intersect :: forall sh. Shape sh => sh -> sh -> sh
intersect a b = toElt (R.intersect shr (fromElt a) (fromElt b))
  where
    shr = shapeR @sh
-- | Combine two shapes with 'R.union' on their representations.
-- NOTE(review): presumably a per-dimension maximum; confirm in
-- Representation.Shape.
union :: forall sh. Shape sh => sh -> sh -> sh
union a b = toElt (R.union shr (fromElt a) (fromElt b))
  where
    shr = shapeR @sh
-- | Convert a multi-dimensional index into a single 'Int' index within the
-- given extent, via 'R.toIndex'. The linearisation order is whatever
-- 'R.toIndex' uses.
toIndex :: forall sh. Shape sh
        => sh       -- ^ extent (total shape) being indexed into
        -> sh       -- ^ the index to convert
        -> Int      -- ^ corresponding linear index
toIndex extent ix =
  let shr = shapeR @sh
   in R.toIndex shr (fromElt extent) (fromElt ix)
-- | Turn an 'Int' index back into a multi-dimensional index within the given
-- extent, via 'R.fromIndex'. This is the counterpart of 'toIndex'
-- (presumably its inverse; confirm in Representation.Shape).
fromIndex :: forall sh. Shape sh
          => sh     -- ^ extent (total shape)
          -> Int    -- ^ linear index
          -> sh     -- ^ corresponding multi-dimensional index
fromIndex extent i =
  let shr = shapeR @sh
   in toElt (R.fromIndex shr (fromElt extent) i)
-- | Visit the indices of a shape with 'R.iter'. Each index is converted back
-- from its representation before the per-index function is applied. The
-- last two arguments are forwarded unchanged to 'R.iter'. They are presumably
-- a combining operator and the result for a shape with no indices; confirm
-- the traversal order in Representation.Shape.
iter :: forall sh e. Shape sh
     => sh              -- ^ shape whose indices are visited
     -> (sh -> e)       -- ^ function applied at each index
     -> (e -> e -> e)   -- ^ combining function
     -> e               -- ^ initial / empty value
     -> e
iter extent f =
  R.iter (shapeR @sh) (fromElt extent) (\ix -> f (toElt ix))
-- | Like 'iter', but with no initial value; forwards to 'R.iter1'.
-- NOTE(review): behaviour on a shape with no indices is R.iter1's; confirm
-- whether it errors.
iter1 :: forall sh e. Shape sh
      => sh             -- ^ shape whose indices are visited
      -> (sh -> e)      -- ^ function applied at each index
      -> (e -> e -> e)  -- ^ combining function
      -> e
iter1 extent f =
  R.iter1 (shapeR @sh) (fromElt extent) (\ix -> f (toElt ix))
-- | Collapse a pair of shapes into a single shape via 'R.rangeToShape' on
-- their representations. The pair argument is matched strictly.
rangeToShape :: forall sh. Shape sh => (sh, sh) -> sh
rangeToShape (a, b) =
  let shr = shapeR @sh
   in toElt (R.rangeToShape shr (fromElt a, fromElt b))
-- | Expand a shape into a pair of shapes via 'R.shapeToRange'. This is the
-- counterpart of 'rangeToShape'. The result pair is built lazily: the
-- representation pair is only matched when a component is demanded.
shapeToRange :: forall sh. Shape sh => sh -> (sh, sh)
shapeToRange ix = (toElt a, toElt b)
  where
    (a, b) = R.shapeToRange (shapeR @sh) (fromElt ix)
-- | Flatten a shape into a list of 'Int's via 'R.shapeToList'. The element
-- order is R.shapeToList's, and 'showShape' depends on it.
shapeToList :: forall sh. Shape sh => sh -> [Int]
shapeToList ix = R.shapeToList (shapeR @sh) (fromElt ix)
-- | Build a shape from a list of 'Int's via 'R.listToShape'. See
-- 'listToShape'' for the 'Maybe'-returning variant.
-- NOTE(review): presumably fails at runtime when the list length does not
-- match the rank; confirm in Representation.Shape.
listToShape :: forall sh. Shape sh => [Int] -> sh
listToShape dims = toElt (R.listToShape (shapeR @sh) dims)
-- | Like 'listToShape', but returns 'Nothing' exactly when 'R.listToShape''
-- does.
listToShape' :: forall sh. Shape sh => [Int] -> Maybe sh
listToShape' dims = toElt <$> R.listToShape' (shapeR @sh) dims
-- | Render a shape in the same @Z :. a :. b@ notation as its 'Show'
-- instance, working from 'shapeToList'. Each list element is appended to
-- the right of the rendering of the elements after it, so the head of the
-- list is printed last. For example, @Z :. 3 :. 4@ is produced only if
-- 'shapeToList' yields @[4, 3]@ (innermost dimension first).
showShape :: Shape sh => sh -> String
showShape ix = render (shapeToList ix)
  where
    render []       = "Z"
    render (d : ds) = render ds ++ " :. " ++ show d
-- | Compute the shape of a slice from the full shape, using the
-- 'R.SliceIndex' that relates them. The index's second type argument is the
-- slice-shape representation and its fourth is the full-shape
-- representation.
sliceShape
    :: forall slix co sl dim. (Shape sl, Shape dim)
    => R.SliceIndex slix (EltR sl) co (EltR dim)
    -> dim
    -> sl
sliceShape si full = toElt (R.sliceShape si (fromElt full))
-- | List the slice indices that 'R.enumSlices' produces for a full shape,
-- converting each one back from its representation.
-- NOTE(review): presumably every slice index within the bounds; confirm.
enumSlices :: forall slix co sl dim. (Elt slix, Elt dim)
           => R.SliceIndex (EltR slix) sl co (EltR dim)
           -> dim    -- ^ full shape (bounds)
           -> [slix] -- ^ slice indices, in R.enumSlices order
enumSlices si full = [ toElt s | s <- R.enumSlices si (fromElt full) ]
-- | Types usable as array shapes and indices (instances here: 'Z' and
-- @sh :. Int@).
--
-- The equality superclasses say that a shape used as a slice specifier
-- fixes every dimension. The full shape and the co-slice are @sh@ itself,
-- and the remaining slice is 'Z'.
class ( Elt sh
      , Elt (Any sh)
      , FullShape sh ~ sh
      , CoSliceShape sh ~ sh
      , SliceShape sh ~ Z
      )
    => Shape sh where

  -- | Reified representation of the shape type, as passed to the @R.*@
  -- functions.
  shapeR :: R.ShapeR (EltR sh)

  -- | Slice index for slicing @sh@ with 'Any'. The slice and full shape are
  -- both @sh@, and the co-slice is @()@.
  sliceAnyIndex  :: R.SliceIndex (EltR (Any sh)) (EltR sh) () (EltR sh)

  -- | Slice index that fixes every dimension. The slice is @()@, and the
  -- co-slice and full shape are both @sh@.
  sliceNoneIndex :: R.SliceIndex (EltR sh) () (EltR sh) (EltR sh)
-- | Slice specifiers (instances here: 'Z', @sl :. All@, @sl :. Int@ and
-- @Any sh@) together with the three shapes each one determines.
class ( Elt sl
      , Shape (SliceShape sl)
      , Shape (CoSliceShape sl)
      , Shape (FullShape sl)
      )
    => Slice sl where

  -- | Dimensions kept by the slice (the 'All' positions).
  type SliceShape   sl :: Type

  -- | Dimensions fixed by the slice (the 'Int' positions).
  type CoSliceShape sl :: Type

  -- | Shape of the array being sliced (every position).
  type FullShape    sl :: Type

  -- | Representation-level witness relating the specifier to the three
  -- shapes above.
  sliceIndex :: R.SliceIndex (EltR sl)
                             (EltR (SliceShape   sl))
                             (EltR (CoSliceShape sl))
                             (EltR (FullShape    sl))
-- | Division specifiers (instances here: 'Z', @sl :. All@, @sl :. Split@,
-- @Any sh@ and @Divide sh@), each mapped to a slice specifier.
class Slice (DivisionSlice sl) => Division sl where

  -- | The slice specifier corresponding to the division specifier @sl@.
  type DivisionSlice sl :: Type

  -- | The slice index of that slice specifier.
  slicesIndex :: slix ~ DivisionSlice sl
              => R.SliceIndex (EltR slix)
                              (EltR (SliceShape   slix))
                              (EltR (CoSliceShape slix))
                              (EltR (FullShape    slix))
-- | A snoc pair @t :. h@ is represented as the pair of its components'
-- representations. The type representation, tags and conversions are all
-- built component-wise.
instance (Elt t, Elt h) => Elt (t :. h) where
  type EltR (t :. h) = (EltR t, EltR h)

  eltR = TupRpair (eltR @t) (eltR @h)

  -- every combination of a tag for @t@ with a tag for @h@; the @t@ tag
  -- varies slowest
  tagsR = do
    tt <- tagsR @t
    th <- tagsR @h
    pure (TagRpair tt th)

  fromElt (tl :. hd) = (fromElt tl, fromElt hd)

  toElt (a, b) = toElt a :. toElt b
-- | 'Any' over the empty shape. No methods are given, so this relies on the
-- class's default definitions (presumably the Generic-based ones that the
-- @deriving Elt@ clauses in this module use; confirm).
instance Elt (Any Z)

-- | 'Any' over a shape with one more dimension. It is represented as the
-- representation of the smaller 'Any' paired with @()@. Neither 'fromElt'
-- nor 'toElt' inspects its argument.
instance Shape sh => Elt (Any (sh :. Int)) where
  type EltR (Any (sh :. Int)) = (EltR (Any sh), ())
  eltR      = TupRpair (eltR @(Any sh)) TupRunit
  tagsR     = map (\tg -> TagRpair tg TagRunit) (tagsR @(Any sh))
  fromElt _ = (inner, ())
    where
      inner = fromElt (Any :: Any sh)
  toElt     = const Any
-- | The empty shape. It uses the base representation 'R.ShapeRz', and
-- 'R.SliceNil' for both slice indices.
instance Shape Z where
  shapeR = R.ShapeRz
  sliceNoneIndex = R.SliceNil
  sliceAnyIndex = R.SliceNil
-- | One more 'Int' dimension, snoced onto the representation of @sh@.
-- Under 'Any' the new dimension is kept ('R.SliceAll'). With no 'Any' it is
-- fixed ('R.SliceFixed').
instance Shape sh => Shape (sh :. Int) where
  shapeR         = R.ShapeRsnoc $ shapeR @sh
  sliceNoneIndex = R.SliceFixed $ sliceNoneIndex @sh
  sliceAnyIndex  = R.SliceAll $ sliceAnyIndex @sh
-- | The empty slice specifier: slice, co-slice and full shape are all 'Z'.
instance Slice Z where
  type FullShape    Z = Z
  type SliceShape   Z = Z
  type CoSliceShape Z = Z
  sliceIndex = R.SliceNil
-- | 'All' keeps the dimension. It adds an 'Int' to both the slice and the
-- full shape, and leaves the co-slice unchanged.
instance Slice sl => Slice (sl :. All) where
  type SliceShape   (sl :. All) = SliceShape   sl :. Int
  type FullShape    (sl :. All) = FullShape    sl :. Int
  type CoSliceShape (sl :. All) = CoSliceShape sl
  sliceIndex = R.SliceAll $ sliceIndex @sl
-- | An 'Int' position fixes the dimension. It adds an 'Int' to both the
-- co-slice and the full shape, and leaves the slice unchanged.
instance Slice sl => Slice (sl :. Int) where
  type CoSliceShape (sl :. Int) = CoSliceShape sl :. Int
  type FullShape    (sl :. Int) = FullShape    sl :. Int
  type SliceShape   (sl :. Int) = SliceShape   sl
  sliceIndex = R.SliceFixed $ sliceIndex @sl
-- | @Any sh@ keeps every dimension of @sh@ and fixes none. Its slice index
-- is 'sliceAnyIndex'.
instance Shape sh => Slice (Any sh) where
  type FullShape    (Any sh) = sh
  type SliceShape   (Any sh) = sh
  type CoSliceShape (Any sh) = Z
  sliceIndex = sliceAnyIndex @sh
-- | The empty division specifier maps to the empty slice specifier.
instance Division Z where
  slicesIndex = R.SliceNil
  type DivisionSlice Z = Z
-- | An 'All' position stays 'All' in the slice specifier (dimension kept).
instance Division sl => Division (sl :. All) where
  slicesIndex = R.SliceAll $ slicesIndex @sl
  type DivisionSlice (sl :. All) = DivisionSlice sl :. All
-- | A 'Split' position becomes an 'Int' in the slice specifier (dimension
-- fixed).
instance Division sl => Division (sl :. Split) where
  slicesIndex = R.SliceFixed $ slicesIndex @sl
  type DivisionSlice (sl :. Split) = DivisionSlice sl :. Int
-- | @Any sh@ maps to itself as a slice specifier, reusing 'sliceAnyIndex'.
instance Shape sh => Division (Any sh) where
  slicesIndex = sliceAnyIndex @sh
  type DivisionSlice (Any sh) = Any sh
-- | @Divide sh@ maps to the shape @sh@ used as a slice specifier, so every
-- dimension is fixed. This is why it reuses 'sliceNoneIndex'.
instance (Slice sh, Shape sh) => Division (Divide sh) where
  slicesIndex = sliceNoneIndex @sh
  type DivisionSlice (Divide sh) = sh