{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeFamilies        #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Representation.Slice
-- Copyright   : [2008..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Representation.Slice
  where

import Data.Array.Accelerate.Representation.Shape

import Language.Haskell.TH


-- | Slice specifiers, represented as nested pairs.
--
-- Every instance determines, at the type level, the shape of the slice it
-- selects, the shape of its complement, and the full shape it indexes into,
-- together with a value-level 'SliceIndex' relating all four.
--
class Slice sl where
  -- | The projected slice.
  type SliceShape sl
  -- | The complement of the slice.
  type CoSliceShape sl
  -- | The combined dimension.
  type FullShape sl
  -- | Witness tying @sl@ to its three associated shapes.
  sliceIndex
    :: SliceIndex sl (SliceShape sl) (CoSliceShape sl) (FullShape sl)

-- | The empty slice specifier. Every associated shape is unit, and the
-- witness is the terminating constructor 'SliceNil'.
--
instance Slice () where
  type SliceShape    () = ()
  type CoSliceShape  () = ()
  type FullShape     () = ()
  sliceIndex = SliceNil

-- | A @()@ component keeps the whole dimension. It adds an 'Int' extent to
-- both the slice shape and the full shape, and leaves the complement
-- unchanged. The witness for the remaining components @sl@ comes from
-- @sl@'s own 'Slice' instance.
--
instance Slice sl => Slice (sl, ()) where
  type SliceShape   (sl, ()) = (SliceShape  sl, Int)
  type CoSliceShape (sl, ()) = CoSliceShape sl
  type FullShape    (sl, ()) = (FullShape   sl, Int)
  -- @sl@ is in scope from the instance head (ScopedTypeVariables); the type
  -- application names the tail's instance explicitly.
  sliceIndex = SliceAll (sliceIndex @sl)

-- | An 'Int' component fixes the dimension at one position. It adds an 'Int'
-- extent to the complement and to the full shape, and leaves the slice
-- shape unchanged. The witness for the remaining components @sl@ comes from
-- @sl@'s own 'Slice' instance.
--
instance Slice sl => Slice (sl, Int) where
  type SliceShape   (sl, Int) = SliceShape sl
  type CoSliceShape (sl, Int) = (CoSliceShape sl, Int)
  type FullShape    (sl, Int) = (FullShape    sl, Int)
  -- @sl@ is in scope from the instance head (ScopedTypeVariables); the type
  -- application names the tail's instance explicitly.
  sliceIndex = SliceFixed (sliceIndex @sl)

-- | Generalised array index, which may index only a subset of the
-- dimensions of a shape.
--
-- The type indices are, in order:
--
--   * the slice specifier (nested pairs of @()@ and 'Int' components),
--   * the shape of the resulting slice,
--   * the shape of the complement (the fixed dimensions),
--   * the full shape being indexed.
--
data SliceIndex ix slice coSlice sliceDim where
  -- | End of the chain: no dimensions on any index.
  SliceNil
    :: SliceIndex () () () ()

  -- | Keep this dimension: it appears in the slice and in the full shape.
  SliceAll
    :: SliceIndex ix slice co dim
    -> SliceIndex (ix, ()) (slice, Int) co (dim, Int)

  -- | Fix this dimension: it appears in the complement and the full shape.
  SliceFixed
    :: SliceIndex ix slice co dim
    -> SliceIndex (ix, Int) slice (co, Int) (dim, Int)

-- | Render a 'SliceIndex' as its chain of constructors, outermost first.
-- Every nested index is wrapped in parentheses, e.g.
-- @show (SliceAll (SliceFixed SliceNil)) == "SliceAll (SliceFixed (SliceNil))"@.
--
-- Only 'show' is defined, so the default 'showsPrec' ignores the
-- surrounding precedence.
--
instance Show (SliceIndex ix slice coSlice sliceDim) where
  show SliceNil          = "SliceNil"
  show (SliceAll rest)   = "SliceAll (" ++ show rest ++ ")"
  show (SliceFixed rest) = "SliceFixed (" ++ show rest ++ ")"

-- | Project the shape of a slice from the full shape.
--
-- Dimensions marked 'SliceAll' keep their extent from the full shape.
-- Dimensions marked 'SliceFixed' are dropped.
--
sliceShape :: forall slix co sl dim.
              SliceIndex slix sl co dim
           -> dim
           -> sl
sliceShape SliceNil        ()      = ()
sliceShape (SliceAll   sl) (sh, n) = (sliceShape sl sh, n)
sliceShape (SliceFixed sl) (sh, _) = sliceShape sl sh

-- | Shape representation of the slice selected by a 'SliceIndex'.
--
-- Each 'SliceAll' component adds one 'ShapeRsnoc'. 'SliceFixed'
-- components add nothing.
--
sliceShapeR :: SliceIndex slix sl co dim -> ShapeR sl
sliceShapeR SliceNil        = ShapeRz
sliceShapeR (SliceAll sl)   = ShapeRsnoc $ sliceShapeR sl
sliceShapeR (SliceFixed sl) = sliceShapeR sl

-- | Shape representation of the full shape a 'SliceIndex' indexes into.
--
-- Every component, whether 'SliceAll' or 'SliceFixed', adds one
-- 'ShapeRsnoc'.
--
sliceDomainR :: SliceIndex slix sl co dim -> ShapeR dim
sliceDomainR SliceNil        = ShapeRz
sliceDomainR (SliceAll sl)   = ShapeRsnoc $ sliceDomainR sl
sliceDomainR (SliceFixed sl) = ShapeRsnoc $ sliceDomainR sl

-- | Enumerate all slices within a given bound. The innermost dimension changes
-- most rapidly.
--
-- A 'SliceAll' component is always the placeholder @()@. A 'SliceFixed'
-- component ranges over @[0 .. n-1]@, where @n@ is that dimension's extent
-- in the bound. A non-positive extent therefore yields no slices.
--
-- See 'Data.Array.Accelerate.Sugar.Slice.enumSlices' for an example.
--
enumSlices
    :: forall slix co sl dim.
       SliceIndex slix sl co dim
    -> dim
    -> [slix]
enumSlices SliceNil        ()      = [()]
enumSlices (SliceAll   sl) (sh, _) = [ (sh', ()) | sh' <- enumSlices sl sh ]
-- The index @i@ is the inner generator, so the last (innermost) component
-- varies fastest.
enumSlices (SliceFixed sl) (sh, n) = [ (sh', i)  | sh' <- enumSlices sl sh, i <- [0 .. n-1] ]

-- | Reduce a 'SliceIndex' to normal form.
--
-- Recurses through every constructor, so evaluating the result to @()@
-- forces the whole chain. The constructors carry no other fields.
--
rnfSliceIndex :: SliceIndex ix slice co sh -> ()
rnfSliceIndex SliceNil        = ()
rnfSliceIndex (SliceAll sh)   = rnfSliceIndex sh
rnfSliceIndex (SliceFixed sh) = rnfSliceIndex sh

-- | Lift a 'SliceIndex' into a typed Template Haskell expression that
-- rebuilds the same constructor chain.
--
-- NOTE(review): from template-haskell 2.17 (GHC 9.0), typed quotes have type
-- @Code Q a@ rather than @Q (TExp a)@. This signature assumes an older
-- template-haskell; confirm against the supported compiler range.
--
liftSliceIndex :: SliceIndex ix slice co sh -> Q (TExp (SliceIndex ix slice co sh))
liftSliceIndex SliceNil          = [|| SliceNil ||]
liftSliceIndex (SliceAll rest)   = [|| SliceAll $$(liftSliceIndex rest) ||]
liftSliceIndex (SliceFixed rest) = [|| SliceFixed $$(liftSliceIndex rest) ||]