{-# LANGUAGE TypeFamilies, TypeOperators, FlexibleInstances #-}


-- | Index space transformation between arrays and slices.
module Data.Array.Repa.Slice
        ( All           (..)
        , Any           (..)
        , FullShape
        , SliceShape
        , Slice         (..))
where
import Data.Array.Repa.Index
import Prelude                  hiding (replicate, drop)


-- | Slice-specifier component that keeps every index along a dimension:
--   a dimension marked with 'All' appears as an 'Int' dimension in both
--   the full shape and the slice.
data All
        = All


-- | Placeholder for the remaining part of a shape that a slice specifier
--   leaves untouched. The type parameter @sh@ is phantom: it only records
--   the type of that remaining shape.
data Any sh
        = Any


-- | Map the type of a slice specifier to the type of the index in the
--   full (unsliced) shape.
--
--   Both fixed 'Int' positions and 'All' positions correspond to an 'Int'
--   dimension of the full shape, while @'Any' sh@ stands for the
--   remaining shape @sh@ unchanged.
type family FullShape ss
type instance FullShape Z               = Z
type instance FullShape (Any sh)        = sh
type instance FullShape (sl :. Int)     = FullShape sl :. Int
type instance FullShape (sl :. All)     = FullShape sl :. Int


-- | Map the type of a slice specifier to the type of the index in the
--   resulting slice.
--
--   A fixed 'Int' position removes its dimension from the slice, an 'All'
--   position keeps it as an 'Int' dimension, and @'Any' sh@ keeps the
--   remaining shape @sh@ unchanged.
type family SliceShape ss
type instance SliceShape Z              = Z
type instance SliceShape (Any sh)       = sh
type instance SliceShape (sl :. Int)    = SliceShape sl
type instance SliceShape (sl :. All)    = SliceShape sl :. Int


-- | Slice specifiers: types describing how a slice is cut out of a larger
--   index space, together with conversions between indices of the full
--   shape and indices of the slice.
class Slice spec where
        -- | Using the specifier, convert an index of the full shape into the
        --   corresponding index of the slice.
        sliceOfFull     :: spec -> FullShape spec  -> SliceShape spec

        -- | Using the specifier, convert an index of the slice back into the
        --   corresponding index of the full shape.
        fullOfSlice     :: spec -> SliceShape spec -> FullShape  spec


-- | The empty specifier: both 'FullShape' 'Z' and 'SliceShape' 'Z' are 'Z',
--   so each conversion ignores its arguments and returns 'Z'.
instance Slice Z  where
        -- sliceOfFull :: Z -> FullShape Z -> SliceShape Z
        {-# INLINE [1] sliceOfFull #-}
        sliceOfFull _ _         = Z

        -- fullOfSlice :: Z -> SliceShape Z -> FullShape Z
        {-# INLINE [1] fullOfSlice #-}
        fullOfSlice _ _         = Z


-- | @'Any' sh@ maps to @sh@ under both 'FullShape' and 'SliceShape', so
--   each conversion returns its shape argument unchanged.
instance Slice (Any sh) where
        -- sliceOfFull :: Any sh -> FullShape (Any sh) -> SliceShape (Any sh)
        {-# INLINE [1] sliceOfFull #-}
        sliceOfFull _ sh        = sh

        -- fullOfSlice :: Any sh -> SliceShape (Any sh) -> FullShape (Any sh)
        {-# INLINE [1] fullOfSlice #-}
        fullOfSlice _ sh        = sh


-- | A fixed 'Int' position selects one index along its dimension.
--   Going from the full shape to the slice, that dimension is dropped;
--   going from the slice back to the full shape, the fixed index stored in
--   the specifier is re-inserted.
instance Slice sl => Slice (sl :. Int) where
        -- sliceOfFull :: (sl :. Int) -> FullShape (sl :. Int) -> SliceShape (sl :. Int)
        {-# INLINE [1] sliceOfFull #-}
        sliceOfFull (spec :. _) (ix :. _)
                = sliceOfFull spec ix

        -- fullOfSlice :: (sl :. Int) -> SliceShape (sl :. Int) -> FullShape (sl :. Int)
        {-# INLINE [1] fullOfSlice #-}
        fullOfSlice (spec :. n) ix
                = fullOfSlice spec ix :. n


-- | An 'All' position keeps its whole dimension, so the trailing index
--   component is carried across unchanged in both directions while the
--   rest of the index is converted recursively.
instance Slice sl => Slice (sl :. All) where
        -- sliceOfFull :: (sl :. All) -> FullShape (sl :. All) -> SliceShape (sl :. All)
        {-# INLINE [1] sliceOfFull #-}
        sliceOfFull (spec :. All) (ix :. s)
                = sliceOfFull spec ix :. s

        -- fullOfSlice :: (sl :. All) -> SliceShape (sl :. All) -> FullShape (sl :. All)
        {-# INLINE [1] fullOfSlice #-}
        fullOfSlice (spec :. All) (ix :. s)
                = fullOfSlice spec ix :. s