module Data.Array.Accelerate.AST (
  
  Idx(..),
  
  Val(..), prj, idxToInt,
  
  Arrays(..), ArraysR(..), 
  PreOpenAfun(..), OpenAfun, PreAfun, Afun, PreOpenAcc(..), OpenAcc(..), Acc,
  Stencil(..), StencilR(..),
  
  PreOpenFun(..), OpenFun, PreFun, Fun, PreOpenExp(..), OpenExp, PreExp, Exp, PrimConst(..),
  PrimFun(..)
) where
import Data.Typeable
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Tuple
import Data.Array.Accelerate.Array.Representation (SliceIndex)
import Data.Array.Accelerate.Array.Delayed        (Delayable)
import Data.Array.Accelerate.Array.Sugar          as Sugar
#include "accelerate.h"
-- | Typed de Bruijn index into an environment.
--
-- The environment @env@ is a nested pair type whose right-most component
-- is the most recently bound variable; @t@ is the type of the variable the
-- index refers to.
data Idx env t where
  -- Refers to the most recently bound variable.
  ZeroIdx ::              Idx (env, t) t
  -- Skips the most recent binding and indexes into the remaining environment.
  SuccIdx :: Idx env t -> Idx (env, s) t
-- | Run-time environment (valuation) whose structure mirrors the
-- type-level environment @env@ that 'Idx' indexes into.
data Val env where
  -- The empty environment.
  Empty :: Val ()
  -- Extends an environment with one more value, which becomes the
  -- right-most (most recent) component.
  Push  :: Val env -> t -> Val (env, t)
deriving instance Typeable1 Val
-- | Look up the value a de Bruijn index refers to in a valuation.
--
-- The index is inspected first; for either index constructor the valuation
-- must be a 'Push' node. Any other combination raises an internal error.
-- The type indices should make that case unreachable for well-formed
-- arguments.
prj :: Idx env t -> Val env -> t
prj ix env = case ix of
  ZeroIdx     | Push _    v <- env -> v
  SuccIdx ix' | Push rest _ <- env -> prj ix' rest
  _                                -> INTERNAL_ERROR(error) "prj" "inconsistent valuation"
-- | Convert a typed de Bruijn index to its integer position:
-- 'ZeroIdx' is @0@ and every 'SuccIdx' adds one.
idxToInt :: Idx env t -> Int
idxToInt = count 0
  where
    -- Strict accumulator, so no chain of pending (+ 1) thunks is built.
    -- The signature is required for the polymorphic recursion over env.
    count :: Int -> Idx env t -> Int
    count !acc ix = case ix of
      ZeroIdx     -> acc
      SuccIdx ix' -> count (acc + 1) ix'
-- | Types that can be the result of an array computation: the unit type,
-- a single 'Array', or (nested) pairs of these.  Each instance provides a
-- reified description of its structure through 'arrays'.
class (Delayable arrs, Typeable arrs) => Arrays arrs where
  arrays :: ArraysR arrs

-- | Reified structure of an 'Arrays' type.
data ArraysR arrs where
  ArraysRunit  :: ArraysR ()
  ArraysRarray :: (Shape sh, Elt e) => ArraysR (Array sh e)
  ArraysRpair  :: ArraysR arrs1 -> ArraysR arrs2 -> ArraysR (arrs1, arrs2)

-- Each instance returns the matching 'ArraysR' constructor; for pairs the
-- representation is built from the representations of both components.
instance Arrays () where
  arrays = ArraysRunit
instance (Shape sh, Elt e) => Arrays (Array sh e) where
  arrays = ArraysRarray
instance (Arrays arrs1, Arrays arrs2) => Arrays (arrs1, arrs2) where
  arrays = ArraysRpair arrays arrays
-- | Parameterised array-level function abstraction.
--
-- @acc@ is the representation of the array computations in the body,
-- @aenv@ is the array environment, and @t@ is the function type.
data PreOpenAfun acc aenv t where
  -- Function body: no further arguments.
  Abody :: acc             aenv       t -> PreOpenAfun acc aenv t
  -- Lambda abstraction: the argument of type @as@ is bound by extending
  -- the array environment of the inner function.
  Alam  :: (Arrays as, Arrays t)
        => PreOpenAfun acc (aenv, as) t -> PreOpenAfun acc aenv (as -> t)
-- | Array functions over the concrete 'OpenAcc' representation.
type OpenAfun = PreOpenAfun OpenAcc
-- | Parameterised array functions with an empty array environment.
type PreAfun acc = PreOpenAfun acc ()
-- | Array functions over 'OpenAcc' with an empty array environment.
type Afun = OpenAfun ()
-- | Collective array computations.
--
-- @acc@ is the representation used for nested array computations (e.g.
-- 'OpenAcc', which ties the recursive knot), @aenv@ is the array
-- environment that 'Avar' indexes into, and @a@ is the result type.
data PreOpenAcc acc aenv a where

  -- Local binding: the body sees the bound value as the most recent entry
  -- of its array environment.
  Let         :: (Arrays bndArrs, Arrays bodyArrs)
              => acc            aenv            bndArrs               -- bound expression
              -> acc            (aenv, bndArrs) bodyArrs              -- body
              -> PreOpenAcc acc aenv            bodyArrs

  -- Like 'Let', but the two components of a bound pair become two separate
  -- entries of the body's array environment.
  Let2        :: (Arrays bndArrs1, Arrays bndArrs2, Arrays bodyArrs)
              => acc            aenv            (bndArrs1, bndArrs2)  -- bound expression
              -> acc            ((aenv, bndArrs1), bndArrs2)
                                                bodyArrs              -- body
              -> PreOpenAcc acc aenv            bodyArrs
  -- Combine two array computations into a pair of arrays.
  PairArrays  :: (Shape sh1, Shape sh2, Elt e1, Elt e2)
              => acc            aenv (Array sh1 e1)
              -> acc            aenv (Array sh2 e2)
              -> PreOpenAcc acc aenv (Array sh1 e1, Array sh2 e2)

  -- Array variable, given as a de Bruijn index into the array environment.
  Avar        :: Arrays arrs
              => Idx            aenv arrs
              -> PreOpenAcc acc aenv arrs

  -- Apply an array function with an empty array environment ('PreAfun')
  -- to an argument computation.
  Apply       :: (Arrays arrs1, Arrays arrs2)
              => PreAfun    acc      (arrs1 -> arrs2)
              -> acc            aenv arrs1
              -> PreOpenAcc acc aenv arrs2

  -- Choose between two array computations using a scalar Boolean.
  Acond       :: (Arrays arrs)
              => PreExp     acc aenv Bool                              -- condition
              -> acc            aenv arrs                              -- used if True
              -> acc            aenv arrs                              -- used if False
              -> PreOpenAcc acc aenv arrs

  -- Embed a concrete array value.
  Use         :: Array dim e
              -> PreOpenAcc acc aenv (Array dim e)

  -- Wrap the value of a scalar expression in a 'Scalar' array.
  Unit        :: Elt e
              => PreExp     acc aenv e
              -> PreOpenAcc acc aenv (Scalar e)

  -- Give the elements of an array a new shape (which may have a
  -- different rank).
  Reshape     :: (Shape sh, Shape sh', Elt e)
              => PreExp     acc aenv sh                  -- new shape
              -> acc            aenv (Array sh' e)       -- source array
              -> PreOpenAcc acc aenv (Array sh e)

  -- Build an array of the given shape by applying a function to every index.
  Generate    :: (Shape sh, Elt e)
              => PreExp     acc aenv sh                  -- result shape
              -> PreFun     acc aenv (sh -> e)           -- element function
              -> PreOpenAcc acc aenv (Array sh e)

  -- Replicate an array of shape @sl@ into one of shape @sh@, as described
  -- by a slice specifier of type @slix@.
  Replicate   :: (Shape sh, Shape sl, Elt slix, Elt e)
              => SliceIndex (EltRepr slix)               -- slice type specification
                            (EltRepr sl)
                            co'
                            (EltRepr sh)
              -> PreExp     acc aenv slix                -- slice value specification
              -> acc            aenv (Array sl e)        -- data to be replicated
              -> PreOpenAcc acc aenv (Array sh e)

  -- Extract a slice of shape @sl@ from an array of shape @sh@, as described
  -- by a slice specifier of type @slix@.
  Index       :: (Shape sh, Shape sl, Elt slix, Elt e)
              => SliceIndex (EltRepr slix)       -- slice type specification
                            (EltRepr sl)
                            co'
                            (EltRepr sh)
              -> acc            aenv (Array sh e)        -- array to be indexed
              -> PreExp     acc aenv slix                -- slice value specification
              -> PreOpenAcc acc aenv (Array sl e)

  -- Apply a scalar function to every element; the shape is unchanged.
  Map         :: (Shape sh, Elt e, Elt e')
              => PreFun     acc aenv (e -> e')
              -> acc            aenv (Array sh e)
              -> PreOpenAcc acc aenv (Array sh e')

  -- Combine the elements of two arrays of the same shape type with a
  -- binary scalar function.
  ZipWith     :: (Shape sh, Elt e1, Elt e2, Elt e3)
              => PreFun     acc aenv (e1 -> e2 -> e3)
              -> acc            aenv (Array sh e1)
              -> acc            aenv (Array sh e2)
              -> PreOpenAcc acc aenv (Array sh e3)

  -- Reduce away the last index component, starting from an initial value.
  Fold        :: (Shape sh, Elt e)
              => PreFun     acc aenv (e -> e -> e)           -- combination function
              -> PreExp     acc aenv e                       -- initial value
              -> acc            aenv (Array (sh:.Int) e)     -- source array
              -> PreOpenAcc acc aenv (Array sh e)

  -- Like 'Fold', but without an initial value.
  Fold1       :: (Shape sh, Elt e)
              => PreFun     acc aenv (e -> e -> e)           -- combination function
              -> acc            aenv (Array (sh:.Int) e)     -- source array
              -> PreOpenAcc acc aenv (Array sh e)

  -- Segmented reduction along the last index component, with an initial
  -- value; the rank of the result equals that of the source.
  FoldSeg     :: (Shape sh, Elt e)
              => PreFun     acc aenv (e -> e -> e)           -- combination function
              -> PreExp     acc aenv e                       -- initial value
              -> acc            aenv (Array (sh:.Int) e)     -- source array
              -> acc            aenv Segments                -- segment descriptor
              -> PreOpenAcc acc aenv (Array (sh:.Int) e)

  -- Like 'FoldSeg', but without an initial value.
  Fold1Seg    :: (Shape sh, Elt e)
              => PreFun     acc aenv (e -> e -> e)           -- combination function
              -> acc            aenv (Array (sh:.Int) e)     -- source array
              -> acc            aenv Segments                -- segment descriptor
              -> PreOpenAcc acc aenv (Array (sh:.Int) e)

  -- Scans over one-dimensional vectors.  The 'l' variants scan from the
  -- left, the 'r' variants from the right (as the names suggest — confirm
  -- against the backends).
  Scanl       :: Elt e
              => PreFun     acc aenv (e -> e -> e)           -- combination function
              -> PreExp     acc aenv e                       -- initial value
              -> acc            aenv (Vector e)              -- source vector
              -> PreOpenAcc acc aenv (Vector e)

  -- Like 'Scanl', but additionally returns a 'Scalar' result (presumably
  -- the overall reduction value — NOTE(review): confirm).
  Scanl'      :: Elt e
              => PreFun     acc aenv (e -> e -> e)           -- combination function
              -> PreExp     acc aenv e                       -- initial value
              -> acc            aenv (Vector e)              -- source vector
              -> PreOpenAcc acc aenv (Vector e, Scalar e)

  -- Like 'Scanl', but without an initial value.
  Scanl1      :: Elt e
              => PreFun     acc aenv (e -> e -> e)           -- combination function
              -> acc            aenv (Vector e)              -- source vector
              -> PreOpenAcc acc aenv (Vector e)

  -- Right-to-left counterpart of 'Scanl'.
  Scanr       :: Elt e
              => PreFun     acc aenv (e -> e -> e)           -- combination function
              -> PreExp     acc aenv e                       -- initial value
              -> acc            aenv (Vector e)              -- source vector
              -> PreOpenAcc acc aenv (Vector e)

  -- Right-to-left counterpart of Scanl'.
  Scanr'      :: Elt e
              => PreFun     acc aenv (e -> e -> e)           -- combination function
              -> PreExp     acc aenv e                       -- initial value
              -> acc            aenv (Vector e)              -- source vector
              -> PreOpenAcc acc aenv (Vector e, Scalar e)

  -- Right-to-left counterpart of 'Scanl1'.
  Scanr1      :: Elt e
              => PreFun     acc aenv (e -> e -> e)           -- combination function
              -> acc            aenv (Vector e)              -- source vector
              -> PreOpenAcc acc aenv (Vector e)

  -- Forward permutation: each source element is sent to the target index
  -- given by the permutation function. Target positions start from the
  -- default array, and elements meeting at one target are combined with
  -- the combination function.
  --
  -- NOTE(review): unlike 'Backpermute', this constructor has no
  -- @Shape sh'@ constraint, so code matching on 'Permute' cannot use
  -- 'Shape' operations on the target shape without obtaining that
  -- dictionary elsewhere — confirm whether this is intended.
  Permute     :: (Shape sh, Elt e)
              => PreFun     acc aenv (e -> e -> e)        -- combination function
              -> acc            aenv (Array sh' e)        -- default values
              -> PreFun     acc aenv (sh -> sh')          -- permutation function
              -> acc            aenv (Array sh e)         -- source array
              -> PreOpenAcc acc aenv (Array sh' e)

  -- Backward permutation: each result index is mapped to a source index,
  -- whose element is read.
  Backpermute :: (Shape sh, Shape sh', Elt e)
              => PreExp     acc aenv sh'                  -- shape of the result
              -> PreFun     acc aenv (sh' -> sh)          -- permutation function
              -> acc            aenv (Array sh e)         -- source array
              -> PreOpenAcc acc aenv (Array sh' e)

  -- Apply a stencil function to the neighbourhood of every element; the
  -- shape of the result matches the source.
  Stencil :: (Elt e, Elt e', Stencil sh e stencil)
          => PreFun     acc aenv (stencil -> e')          -- stencil function
          -> Boundary            (EltRepr e)              -- boundary condition
          -> acc            aenv (Array sh e)             -- source array
          -> PreOpenAcc acc aenv (Array sh e')
  -- Binary stencil: one stencil function over neighbourhoods of two
  -- source arrays of the same shape type, each with its own boundary.
  Stencil2 :: (Elt e1, Elt e2, Elt e',
               Stencil sh e1 stencil1,
               Stencil sh e2 stencil2)
           => PreFun     acc aenv (stencil1 ->
                                   stencil2 -> e')        -- stencil function
           -> Boundary            (EltRepr e1)            -- boundary condition #1
           -> acc            aenv (Array sh e1)           -- source array #1
           -> Boundary            (EltRepr e2)            -- boundary condition #2
           -> acc            aenv (Array sh e2)           -- source array #2
           -> PreOpenAcc acc aenv (Array sh e')
-- | Array computations: 'PreOpenAcc' with its nested-computation
-- positions instantiated to 'OpenAcc' itself, closing the recursive knot.
newtype OpenAcc aenv t = OpenAcc (PreOpenAcc OpenAcc aenv t)
deriving instance Typeable2 OpenAcc
-- | Array computations with an empty array environment.
type Acc = OpenAcc ()
-- | Relates a shape @sh@ and element type @e@ to the tuple type @stencil@
-- in which a neighbourhood of elements is presented to a stencil function.
class (Shape sh, Elt e, IsTuple stencil) => Stencil sh e stencil where
  -- Reified description of the stencil pattern.
  stencil       :: StencilR sh e stencil
  -- Gather the neighbourhood around a focal index, given a function that
  -- reads the element at an index.
  stencilAccess :: (sh -> e) -> sh -> stencil
-- | Reified representation of a stencil pattern.
--
-- The @StencilRunit*@ constructors describe one-dimensional patterns of
-- 3, 5, 7 or 9 elements.  Each @StencilRtup*@ constructor adds one
-- dimension (a trailing @:.Int@ component) by combining 3, 5, 7 or 9
-- patterns of the lower-rank shape.
data StencilR sh e pat where
  StencilRunit3 :: (Elt e)
                => StencilR DIM1 e (e,e,e)
  StencilRunit5 :: (Elt e)
                => StencilR DIM1 e (e,e,e,e,e)
  StencilRunit7 :: (Elt e)
                => StencilR DIM1 e (e,e,e,e,e,e,e)
  StencilRunit9 :: (Elt e)
                => StencilR DIM1 e (e,e,e,e,e,e,e,e,e)
  StencilRtup3  :: (Shape sh, Elt e)
                => StencilR sh e pat1
                -> StencilR sh e pat2
                -> StencilR sh e pat3
                -> StencilR (sh:.Int) e (pat1,pat2,pat3)
  StencilRtup5  :: (Shape sh, Elt e)
                => StencilR sh e pat1
                -> StencilR sh e pat2
                -> StencilR sh e pat3
                -> StencilR sh e pat4
                -> StencilR sh e pat5
                -> StencilR (sh:.Int) e (pat1,pat2,pat3,pat4,pat5)
  StencilRtup7  :: (Shape sh, Elt e)
                => StencilR sh e pat1
                -> StencilR sh e pat2
                -> StencilR sh e pat3
                -> StencilR sh e pat4
                -> StencilR sh e pat5
                -> StencilR sh e pat6
                -> StencilR sh e pat7
                -> StencilR (sh:.Int) e (pat1,pat2,pat3,pat4,pat5,pat6,pat7)
  StencilRtup9  :: (Shape sh, Elt e)
                => StencilR sh e pat1
                -> StencilR sh e pat2
                -> StencilR sh e pat3
                -> StencilR sh e pat4
                -> StencilR sh e pat5
                -> StencilR sh e pat6
                -> StencilR sh e pat7
                -> StencilR sh e pat8
                -> StencilR sh e pat9
                -> StencilR (sh:.Int) e (pat1,pat2,pat3,pat4,pat5,pat6,pat7,pat8,pat9)
-- | One-dimensional 3-point stencil: the elements at offsets -1, 0 and +1
-- from the focal index, in that order.
instance Elt e => Stencil DIM1 e (e, e, e) where
  stencil = StencilRunit3
  stencilAccess rf (Z:.y) = (rf' (y - 1),
                             rf' y      ,
                             rf' (y + 1))
    where
      -- Read the element at a one-dimensional index.
      rf' d = rf (Z:.d)
-- | One-dimensional 5-point stencil: the elements at offsets -2 .. +2 from
-- the focal index, in increasing order.
instance Elt e => Stencil DIM1 e (e, e, e, e, e) where
  stencil = StencilRunit5
  stencilAccess rf (Z:.y) = (rf' (y - 2),
                             rf' (y - 1),
                             rf' y      ,
                             rf' (y + 1),
                             rf' (y + 2))
    where
      -- Read the element at a one-dimensional index.
      rf' d = rf (Z:.d)
-- | One-dimensional 7-point stencil: the elements at offsets -3 .. +3 from
-- the focal index, in increasing order.
instance Elt e => Stencil DIM1 e (e, e, e, e, e, e, e) where
  stencil = StencilRunit7
  stencilAccess rf (Z:.y) = (rf' (y - 3),
                             rf' (y - 2),
                             rf' (y - 1),
                             rf' y      ,
                             rf' (y + 1),
                             rf' (y + 2),
                             rf' (y + 3))
    where
      -- Read the element at a one-dimensional index.
      rf' d = rf (Z:.d)
-- | One-dimensional 9-point stencil: the elements at offsets -4 .. +4 from
-- the focal index, in increasing order.
instance Elt e => Stencil DIM1 e (e, e, e, e, e, e, e, e, e) where
  stencil = StencilRunit9
  stencilAccess rf (Z:.y) = (rf' (y - 4),
                             rf' (y - 3),
                             rf' (y - 2),
                             rf' (y - 1),
                             rf' y      ,
                             rf' (y + 1),
                             rf' (y + 2),
                             rf' (y + 3),
                             rf' (y + 4))
    where
      -- Read the element at a one-dimensional index.
      rf' d = rf (Z:.d)
-- | Three-row stencil of one higher rank.  Row @k@ is the lower-rank
-- stencil obtained by fixing the last index component at offset -1, 0 or
-- +1 from the focal index; the sub-stencil ranges over the other components.
instance (Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row3) => Stencil (sh:.Int:.Int) a (row1, row2, row3) where
  stencil = StencilRtup3 stencil stencil stencil
  stencilAccess rf (ix:.i) = (stencilAccess (rf' (i - 1)) ix,
                              stencilAccess (rf'  i     ) ix,
                              stencilAccess (rf' (i + 1)) ix)
    where
      -- Element reader with the last index component fixed to d.
      rf' d ds = rf (ds :. d)
-- | Five-row stencil of one higher rank.  The rows are the lower-rank
-- stencils obtained by fixing the last index component at offsets -2 .. +2
-- from the focal index, in increasing order.
instance (Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row3,
          Stencil (sh:.Int) a row4,
          Stencil (sh:.Int) a row5) => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5) where
  stencil = StencilRtup5 stencil stencil stencil stencil stencil
  stencilAccess rf (ix:.i) = (stencilAccess (rf' (i - 2)) ix,
                              stencilAccess (rf' (i - 1)) ix,
                              stencilAccess (rf'  i     ) ix,
                              stencilAccess (rf' (i + 1)) ix,
                              stencilAccess (rf' (i + 2)) ix)
    where
      -- Element reader with the last index component fixed to d.
      rf' d ds = rf (ds :. d)
-- | Seven-row stencil of one higher rank.  The rows are the lower-rank
-- stencils obtained by fixing the last index component at offsets -3 .. +3
-- from the focal index, in increasing order.
instance (Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row3,
          Stencil (sh:.Int) a row4,
          Stencil (sh:.Int) a row5,
          Stencil (sh:.Int) a row6,
          Stencil (sh:.Int) a row7)
  => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5, row6, row7) where
  stencil = StencilRtup7 stencil stencil stencil stencil stencil stencil stencil
  stencilAccess rf (ix:.i) = (stencilAccess (rf' (i - 3)) ix,
                              stencilAccess (rf' (i - 2)) ix,
                              stencilAccess (rf' (i - 1)) ix,
                              stencilAccess (rf'  i     ) ix,
                              stencilAccess (rf' (i + 1)) ix,
                              stencilAccess (rf' (i + 2)) ix,
                              stencilAccess (rf' (i + 3)) ix)
    where
      -- Element reader with the last index component fixed to d.
      rf' d ds = rf (ds :. d)
-- | Nine-row stencil of one higher rank.  The rows are the lower-rank
-- stencils obtained by fixing the last index component at offsets -4 .. +4
-- from the focal index, in increasing order.
instance (Stencil (sh:.Int) a row1,
          Stencil (sh:.Int) a row2,
          Stencil (sh:.Int) a row3,
          Stencil (sh:.Int) a row4,
          Stencil (sh:.Int) a row5,
          Stencil (sh:.Int) a row6,
          Stencil (sh:.Int) a row7,
          Stencil (sh:.Int) a row8,
          Stencil (sh:.Int) a row9)
  => Stencil (sh:.Int:.Int) a (row1, row2, row3, row4, row5, row6, row7, row8, row9) where
  stencil = StencilRtup9 stencil stencil stencil stencil stencil stencil stencil stencil stencil
  stencilAccess rf (ix:.i) = (stencilAccess (rf' (i - 4)) ix,
                              stencilAccess (rf' (i - 3)) ix,
                              stencilAccess (rf' (i - 2)) ix,
                              stencilAccess (rf' (i - 1)) ix,
                              stencilAccess (rf'  i     ) ix,
                              stencilAccess (rf' (i + 1)) ix,
                              stencilAccess (rf' (i + 2)) ix,
                              stencilAccess (rf' (i + 3)) ix,
                              stencilAccess (rf' (i + 4)) ix)
    where
      -- Element reader with the last index component fixed to d.
      rf' d ds = rf (ds :. d)
-- | Parameterised scalar function abstraction.
--
-- @env@ is the scalar environment, in which arguments are bound at their
-- 'EltRepr' representation type; @aenv@ is the array environment.
data PreOpenFun (acc :: * -> * -> *) env aenv t where
  -- Function body: no further arguments.
  Body :: PreOpenExp acc env              aenv t -> PreOpenFun acc env aenv t
  -- Lambda abstraction binding one argument of element type @a@.
  Lam  :: Elt a
       => PreOpenFun acc (env, EltRepr a) aenv t -> PreOpenFun acc env aenv (a -> t)
-- | Scalar functions whose embedded array computations are 'OpenAcc'.
type OpenFun = PreOpenFun OpenAcc
-- | Parameterised scalar functions with an empty scalar environment.
type PreFun acc = PreOpenFun acc ()
-- | Scalar functions over 'OpenAcc' with an empty scalar environment.
type Fun = OpenFun ()
-- | Parameterised scalar expressions.
--
-- @env@ is the scalar environment that 'Var' indexes into, @aenv@ is the
-- array environment seen by embedded array computations of type @acc@, and
-- @t@ is the result type.
data PreOpenExp (acc :: * -> * -> *) env aenv t where

  -- Scalar variable: de Bruijn index into the scalar environment, which
  -- stores values at their 'EltRepr' representation type.
  Var         :: Elt t
              => Idx env (EltRepr t)
              -> PreOpenExp acc env aenv t

  -- Constant, given in its 'EltRepr' representation.
  Const       :: Elt t
              => EltRepr t
              -> PreOpenExp acc env aenv t


  -- Build a tuple from its component expressions.
  Tuple       :: (Elt t, IsTuple t)
              => Tuple (PreOpenExp acc env aenv) (TupleRepr t)
              -> PreOpenExp acc env aenv t
  -- Project one component out of a tuple.
  Prj         :: (Elt t, IsTuple t)
              => TupleIdx (TupleRepr t) e
              -> PreOpenExp acc env aenv t
              -> PreOpenExp acc env aenv e

  -- Index construction and destruction.
  IndexNil    :: PreOpenExp acc env aenv Z
  IndexCons   :: (Slice sl, Elt a)
              => PreOpenExp acc env aenv sl
              -> PreOpenExp acc env aenv a
              -> PreOpenExp acc env aenv (sl:.a)
  IndexHead   :: (Slice sl, Elt a)
              => PreOpenExp acc env aenv (sl:.a)
              -> PreOpenExp acc env aenv a
  IndexTail   :: (Slice sl, Elt a)
              => PreOpenExp acc env aenv (sl:.a)
              -> PreOpenExp acc env aenv sl
  IndexAny    :: Shape sh => PreOpenExp acc env aenv (Any sh)

  -- Conditional expression.
  Cond        :: PreOpenExp acc env aenv Bool
              -> PreOpenExp acc env aenv t
              -> PreOpenExp acc env aenv t
              -> PreOpenExp acc env aenv t

  -- Primitive constant.
  PrimConst   :: Elt t
              => PrimConst t
              -> PreOpenExp acc env aenv t

  -- Application of a primitive function to a single argument (binary
  -- primitives take their operands as a pair).
  PrimApp     :: (Elt a, Elt r)
              => PrimFun (a -> r)
              -> PreOpenExp acc env aenv a
              -> PreOpenExp acc env aenv r

  -- Read the element of an array (from the array environment) at the
  -- given index.
  IndexScalar :: (Shape dim, Elt t)
              => acc                aenv (Array dim t)
              -> PreOpenExp acc env aenv dim
              -> PreOpenExp acc env aenv t

  -- Shape of an array.
  Shape       :: (Shape dim, Elt e)
              => acc                aenv (Array dim e)
              -> PreOpenExp acc env aenv dim

  -- Size of an array, as an 'Int'.
  Size        :: (Shape dim, Elt e)
              => acc                aenv (Array dim e)
              -> PreOpenExp acc env aenv Int
-- | Scalar expressions whose embedded array computations are 'OpenAcc'.
type OpenExp = PreOpenExp OpenAcc
-- | Parameterised scalar expressions with an empty scalar environment.
type PreExp acc = PreOpenExp acc ()
-- | Scalar expressions over 'OpenAcc' with an empty scalar environment.
type Exp = OpenExp ()
-- | Primitive scalar constants, each carrying a witness of the type it is
-- defined at.
data PrimConst ty where

  -- Bounds of a bounded type.
  PrimMinBound  :: BoundedType a -> PrimConst a
  PrimMaxBound  :: BoundedType a -> PrimConst a

  -- Constant defined for floating-point types.
  PrimPi        :: FloatingType a -> PrimConst a
-- | Primitive scalar functions.
--
-- Every primitive takes a single argument; binary operations receive their
-- two operands as a pair.  Most constructors carry a type witness that
-- fixes the type they operate on.
data PrimFun sig where

  -- Operations on any numeric type ('NumType').
  PrimAdd  :: NumType a -> PrimFun ((a, a) -> a)
  PrimSub  :: NumType a -> PrimFun ((a, a) -> a)
  PrimMul  :: NumType a -> PrimFun ((a, a) -> a)
  PrimNeg  :: NumType a -> PrimFun (a      -> a)
  PrimAbs  :: NumType a -> PrimFun (a      -> a)
  PrimSig  :: NumType a -> PrimFun (a      -> a)

  -- Operations on integral types ('IntegralType'); shifts and rotations
  -- take their distance as an 'Int'.
  PrimQuot     :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimRem      :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimIDiv     :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimMod      :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimBAnd     :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimBOr      :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimBXor     :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimBNot     :: IntegralType a -> PrimFun (a        -> a)
  PrimBShiftL  :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBShiftR  :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBRotateL :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBRotateR :: IntegralType a -> PrimFun ((a, Int) -> a)


  -- Operations on floating-point types ('FloatingType'), including the
  -- rounding conversions to integral types at the end of the group.
  PrimFDiv        :: FloatingType a -> PrimFun ((a, a) -> a)
  PrimRecip       :: FloatingType a -> PrimFun (a      -> a)
  PrimSin         :: FloatingType a -> PrimFun (a      -> a)
  PrimCos         :: FloatingType a -> PrimFun (a      -> a)
  PrimTan         :: FloatingType a -> PrimFun (a      -> a)
  PrimAsin        :: FloatingType a -> PrimFun (a      -> a)
  PrimAcos        :: FloatingType a -> PrimFun (a      -> a)
  PrimAtan        :: FloatingType a -> PrimFun (a      -> a)
  PrimAsinh       :: FloatingType a -> PrimFun (a      -> a)
  PrimAcosh       :: FloatingType a -> PrimFun (a      -> a)
  PrimAtanh       :: FloatingType a -> PrimFun (a      -> a)
  PrimExpFloating :: FloatingType a -> PrimFun (a      -> a)
  PrimSqrt        :: FloatingType a -> PrimFun (a      -> a)
  PrimLog         :: FloatingType a -> PrimFun (a      -> a)
  PrimFPow        :: FloatingType a -> PrimFun ((a, a) -> a)
  PrimLogBase     :: FloatingType a -> PrimFun ((a, a) -> a)
  PrimAtan2       :: FloatingType a -> PrimFun ((a, a) -> a)
  PrimTruncate    :: FloatingType a -> IntegralType b -> PrimFun (a -> b)
  PrimRound       :: FloatingType a -> IntegralType b -> PrimFun (a -> b)
  PrimFloor       :: FloatingType a -> IntegralType b -> PrimFun (a -> b)
  PrimCeiling     :: FloatingType a -> IntegralType b -> PrimFun (a -> b)


  -- Relational and min/max operations on any scalar type ('ScalarType').
  PrimLt   :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimGt   :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimLtEq :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimGtEq :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimEq   :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimNEq  :: ScalarType a -> PrimFun ((a, a) -> Bool)
  PrimMax  :: ScalarType a -> PrimFun ((a, a) -> a   )
  PrimMin  :: ScalarType a -> PrimFun ((a, a) -> a   )

  -- Logical operations on 'Bool'.
  PrimLAnd :: PrimFun ((Bool, Bool) -> Bool)
  PrimLOr  :: PrimFun ((Bool, Bool) -> Bool)
  PrimLNot :: PrimFun (Bool         -> Bool)

  -- Conversions between 'Char' and 'Int'.
  PrimOrd  :: PrimFun (Char -> Int)
  PrimChr  :: PrimFun (Int  -> Char)




  -- Other type conversions.
  PrimBoolToInt    :: PrimFun (Bool -> Int)
  PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b)