{-# LANGUAGE FlexibleContexts, TypeFamilies, RankNTypes, ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses #-}
{-# OPTIONS_GHC -fno-warn-missing-methods #-}

-- Module      : Data.Array.Accelerate.Language
-- Copyright   : [2009..2010] Manuel M T Chakravarty, Gabriele Keller, Sean Lee
-- License     : BSD3
--
-- Maintainer  : Manuel M T Chakravarty <chak@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- We use the dictionary view of overloaded operations (such as arithmetic and
-- bit manipulation) to reify such expressions.  With non-overloaded
-- operations (such as the logical connectives) and partially overloaded
-- operations (such as comparisons), we use the standard operator names with a
-- '*' attached.  We keep the standard alphanumeric names as they can be
-- easily qualified.

module Data.Array.Accelerate.Language (

  -- ** Array and scalar expressions
  Acc, Exp,                                 -- re-exporting from 'Smart'
  
  -- ** Stencil specification
  Boundary(..), Stencil,                    -- re-exporting from 'Smart'

  -- ** Common stencil types
  Stencil3, Stencil5, Stencil7, Stencil9,
  Stencil3x3, Stencil5x3, Stencil3x5, Stencil5x5,
  Stencil3x3x3, Stencil5x3x3, Stencil3x5x3, Stencil3x3x5, Stencil5x5x3, Stencil5x3x5,
  Stencil3x5x5, Stencil5x5x5,

  -- ** Scalar introduction
  constant,                                 -- re-exporting from 'Smart'

  -- ** Array introduction
  use, unit,

  -- ** Shape manipulation
  reshape,

  -- ** Collective array operations
  slice, replicate, zip, unzip, map, zipWith, scanl, scanr, fold, foldSeg,
  permute, backpermute, stencil, stencil2,
  
  -- ** Tuple construction and destruction
  Tuple(..), fst, snd, curry, uncurry,
  
  -- ** Conditional expressions
  (?),
  
  -- ** Array operations with a scalar result
  (!), shape,
  
  -- ** Methods of H98 classes that we need to redefine as their signatures change
  (==*), (/=*), (<*), (<=*), (>*), (>=*), max, min,
  bit, setBit, clearBit, complementBit, testBit,
  shift,  shiftL,  shiftR,
  rotate, rotateL, rotateR,

  -- ** Standard functions that we need to redefine as their signatures change
  (&&*), (||*), not,
  
  -- ** Conversions
  boolToInt, intToFloat, roundFloatToInt, truncateFloatToInt,

  -- ** Constants
  ignore

  -- ** Instances of Bounded, Enum, Eq, Ord, Bits, Num, Real, Floating,
  --    Fractional, RealFrac, RealFloat

) where

-- avoid clashes with Prelude functions
import Prelude   hiding (replicate, zip, unzip, map, scanl, scanr, zipWith,
                         filter, max, min, not, const, fst, snd, curry, uncurry)

-- standard libraries
import Data.Bits (Bits((.&.), (.|.), xor, complement))

-- friends
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Array.Sugar hiding ((!), ignore, shape)
import qualified Data.Array.Accelerate.Array.Sugar as Sugar
import Data.Array.Accelerate.Smart


-- Collective operations
-- ---------------------

-- |Array inlet: lift an existing array into the Accelerate language so that
-- collective operations can process it.  This may trigger an asynchronous
-- host->device transfer if necessary.
--
use :: (Ix dim, Elem e) => Array dim e -> Acc (Array dim e)
use arr = Use arr

-- |Scalar inlet: wrap a scalar expression (or an expression of tuple type)
-- into a singleton array that the Accelerate language can process.
--
unit :: Elem e => Exp e -> Acc (Scalar e)
unit x = Unit x

-- |Give an array a new shape (first argument) while keeping its contents
-- unchanged.  The caller must ensure
--
-- > precondition: size dim == size dim'
--
reshape :: (Ix dim, Ix dim', Elem e)
        => Exp dim                -- ^new shape
        -> Acc (Array dim' e)     -- ^array to reshape
        -> Acc (Array dim e)
reshape newShape arr = Reshape newShape arr

-- |Replicate an array across one or more extra dimensions.  The
-- *generalised* array index in the first argument says how the array is
-- replicated.
--
-- For example, if 'arr' is a vector (one-dimensional array), then
--
-- > replicate (2, All, 3) arr
--
-- is a three-dimensional array: 'arr' is repeated twice along the first
-- dimension and three times along the third.
--
replicate :: (SliceIx slix, Elem e)
          => Exp slix                         -- ^generalised index
          -> Acc (Array (Slice    slix) e)    -- ^array to replicate
          -> Acc (Array (SliceDim slix) e)
replicate slix arr = Replicate slix arr

-- |Index an array with a *generalised* array index (the second argument).
-- The result is a new array, possibly a singleton, made up of the
-- dimensions that the index keeps in their entirety.
--
slice :: (SliceIx slix, Elem e)
      => Acc (Array (SliceDim slix) e)    -- ^array to slice
      -> Exp slix                         -- ^generalised index
      -> Acc (Array (Slice slix) e)
slice arr slix = Index arr slix

-- |Pair up the elements of two arrays position by position.  The result has
-- the intersection of the two argument shapes.
--
zip :: (Ix dim, Elem a, Elem b)
    => Acc (Array dim a)
    -> Acc (Array dim b)
    -> Acc (Array dim (a, b))
zip xs ys = zipWith (\a b -> tuple (a, b)) xs ys

-- |Split an array of pairs into an array of first components and an array of
-- second components.  Unlike 'zip', no intersection takes place: both
-- results have the same shape as the argument.
--
unzip :: (Ix dim, Elem a, Elem b)
      => Acc (Array dim (a, b))
      -> (Acc (Array dim a), Acc (Array dim b))
unzip arr = (firsts, seconds)
  where
    firsts  = map fst arr
    seconds = map snd arr

-- |Apply a scalar function to every element of an array.
--
map :: (Ix dim, Elem a, Elem b)
    => (Exp a -> Exp b)       -- ^function applied elementwise
    -> Acc (Array dim a)
    -> Acc (Array dim b)
map f arr = Map f arr

-- |Combine two arrays elementwise with a binary scalar function.  The
-- resulting array has the intersection of the extents of the two sources.
--
zipWith :: (Ix dim, Elem a, Elem b, Elem c)
        => (Exp a -> Exp b -> Exp c)    -- ^combining function
        -> Acc (Array dim a)
        -> Acc (Array dim b)
        -> Acc (Array dim c)
zipWith f xs ys = ZipWith f xs ys

-- |Left-to-right prescan of a vector.  The binary function (first argument)
-- and the value (second argument) must form a monoid over 'a': the function
-- must be /associative/ and the value must be its /neutral element/.
--
-- Returns a pair: the vector of prescan values, which has the same length as
-- the input, and a singleton holding the overall reduction value.
--
scanl :: Elem a
      => (Exp a -> Exp a -> Exp a)
      -> Exp a
      -> Acc (Vector a)
      -> (Acc (Vector a), Acc (Scalar a))
scanl f e = unpair . Scanl f e

-- |Right-to-left counterpart of 'scanl'.  It has the same monoid requirement
-- and the same result pair: the prescan vector and the reduction value.
--
scanr :: Elem a
      => (Exp a -> Exp a -> Exp a)
      -> Exp a
      -> Acc (Vector a)
      -> (Acc (Vector a), Acc (Scalar a))
scanr f e = unpair . Scanr f e

-- |Reduce a whole array to a single value.  The binary function (first
-- argument) and the value (second argument) must form a monoid over 'a': the
-- function must be /associative/ and the value must be its /neutral element/.
--
fold :: (Ix dim, Elem a)
     => (Exp a -> Exp a -> Exp a)   -- ^associative operator
     -> Exp a                       -- ^neutral element
     -> Acc (Array dim a)
     -> Acc (Scalar a)
fold f z arr = Fold f z arr

-- |Segmented reduction: reduce the data vector segment by segment, where the
-- 'Segments' array describes how the vector is divided.
--
foldSeg :: Elem a
        => (Exp a -> Exp a -> Exp a)  -- ^combination function
        -> Exp a                      -- ^initial / neutral value
        -> Acc (Vector a)             -- ^data vector
        -> Acc Segments               -- ^segment descriptor
        -> Acc (Vector a)
foldSeg f z arr segs = FoldSeg f z arr segs

-- |Forward permutation driven by an index mapping.  The result starts out as
-- the array of defaults.  Each source element is then sent to the position
-- given by the permutation and merged with the value already there using the
-- combination function.
--
-- The combination function must be /associative/.  Any element that the
-- permutation maps to the magic value 'ignore' is dropped.
--
permute :: (Ix dim, Ix dim', Elem a)
        => (Exp a -> Exp a -> Exp a)    -- ^combination function
        -> Acc (Array dim' a)           -- ^array of default values
        -> (Exp dim -> Exp dim')        -- ^permutation
        -> Acc (Array dim  a)           -- ^permuted array
        -> Acc (Array dim' a)
permute combine defaults perm arr = Permute combine defaults perm arr

-- |Backward permutation: the result has the given shape.  The permutation
-- maps each result index to the source index it reads from.
--
backpermute :: (Ix dim, Ix dim', Elem a)
            => Exp dim'                 -- ^shape of the result array
            -> (Exp dim' -> Exp dim)    -- ^permutation
            -> Acc (Array dim  a)       -- ^permuted array
            -> Acc (Array dim' a)
backpermute sh perm arr = Backpermute sh perm arr


-- Common stencil types
--
-- Each type is a nested tuple of scalar expressions describing one
-- neighbourhood.  In a name such as Stencil5x3 or Stencil5x3x5, the first
-- number is the length of the innermost tuple and the last number is the
-- length of the outermost tuple.  For example, 'Stencil5x3' is three
-- 'Stencil5' rows, and 'Stencil3x3x5' is five 'Stencil3x3' planes.

-- DIM1 stencil type
type Stencil3 a = (Exp a, Exp a, Exp a)
type Stencil5 a = (Exp a, Exp a, Exp a, Exp a, Exp a)
type Stencil7 a = (Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a)
type Stencil9 a = (Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a, Exp a)

-- DIM2 stencil type (tuples of DIM1 rows)
type Stencil3x3 a = (Stencil3 a, Stencil3 a, Stencil3 a)
type Stencil5x3 a = (Stencil5 a, Stencil5 a, Stencil5 a)
type Stencil3x5 a = (Stencil3 a, Stencil3 a, Stencil3 a, Stencil3 a, Stencil3 a)
type Stencil5x5 a = (Stencil5 a, Stencil5 a, Stencil5 a, Stencil5 a, Stencil5 a)

-- DIM3 stencil type (tuples of DIM2 planes)
type Stencil3x3x3 a = (Stencil3x3 a, Stencil3x3 a, Stencil3x3 a)
type Stencil5x3x3 a = (Stencil5x3 a, Stencil5x3 a, Stencil5x3 a)
type Stencil3x5x3 a = (Stencil3x5 a, Stencil3x5 a, Stencil3x5 a)
type Stencil3x3x5 a = (Stencil3x3 a, Stencil3x3 a, Stencil3x3 a, Stencil3x3 a, Stencil3x3 a)
type Stencil5x5x3 a = (Stencil5x5 a, Stencil5x5 a, Stencil5x5 a)
type Stencil5x3x5 a = (Stencil5x3 a, Stencil5x3 a, Stencil5x3 a, Stencil5x3 a, Stencil5x3 a)
type Stencil3x5x5 a = (Stencil3x5 a, Stencil3x5 a, Stencil3x5 a, Stencil3x5 a, Stencil3x5 a)
type Stencil5x5x5 a = (Stencil5x5 a, Stencil5x5 a, Stencil5x5 a, Stencil5x5 a, Stencil5x5 a)

-- |Map a stencil over an array.  Unlike 'map', the stencil function does not
-- see a single element; it sees a whole /neighbourhood/ of each element.  A
-- neighbourhood is a sub-array centred around a focal point.  It need not be
-- rectangular, but it is symmetric in each dimension and has an extent of at
-- least three in every dimension.  Because of the symmetry, each extent is odd.
-- The focal point is the array position determined by the stencil.
--
-- Where a neighbourhood extends past the edges of the source array, the
-- boundary condition supplies the contents of the out-of-bounds positions.
--
stencil :: (Ix dim, Elem a, Elem b, Stencil dim a stencil)
        => (stencil -> Exp b)                 -- ^stencil function
        -> Boundary a                         -- ^boundary condition
        -> Acc (Array dim a)                  -- ^source array
        -> Acc (Array dim b)                  -- ^destination array
stencil f boundary arr = Stencil f boundary arr

-- |Map a binary stencil over two arrays.  Each source array has its own
-- boundary condition.  The resulting array has the intersection of the
-- extents of the two sources.
--
stencil2 :: (Ix dim, Elem a, Elem b, Elem c,
             Stencil dim a stencil1,
             Stencil dim b stencil2)
        => (stencil1 -> stencil2 -> Exp c)    -- ^binary stencil function
        -> Boundary a                         -- ^boundary condition #1
        -> Acc (Array dim a)                  -- ^source array #1
        -> Boundary b                         -- ^boundary condition #2
        -> Acc (Array dim b)                  -- ^source array #2
        -> Acc (Array dim c)                  -- ^destination array
stencil2 f bndy1 arr1 bndy2 arr2 = Stencil2 f bndy1 arr1 bndy2 arr2


-- Tuples
-- ------

-- |Conversion between a tuple of scalar expressions and a single scalar
-- expression of tuple type.  'TupleT' names the tuple-typed expression that
-- corresponds to a tuple of expressions.
--
class Tuple tup where
  type TupleT tup

  -- |Pack a tuple of scalar expressions into one scalar expression that
  -- yields a tuple.
  tuple :: tup -> TupleT tup

  -- |Unpack a scalar expression that yields a tuple into a tuple of scalar
  -- expressions.
  untuple :: TupleT tup -> tup
  
-- Instances for tuples of two up to nine scalar expressions.  Each one
-- delegates to the matching tupN / untupN builder.

instance (Elem a, Elem b)
      => Tuple (Exp a, Exp b) where
  type TupleT (Exp a, Exp b) = Exp (a, b)
  tuple   t = tup2 t
  untuple e = untup2 e

instance (Elem a, Elem b, Elem c)
      => Tuple (Exp a, Exp b, Exp c) where
  type TupleT (Exp a, Exp b, Exp c) = Exp (a, b, c)
  tuple   t = tup3 t
  untuple e = untup3 e

instance (Elem a, Elem b, Elem c, Elem d)
      => Tuple (Exp a, Exp b, Exp c, Exp d) where
  type TupleT (Exp a, Exp b, Exp c, Exp d) = Exp (a, b, c, d)
  tuple   t = tup4 t
  untuple e = untup4 e

instance (Elem a, Elem b, Elem c, Elem d, Elem e)
      => Tuple (Exp a, Exp b, Exp c, Exp d, Exp e) where
  type TupleT (Exp a, Exp b, Exp c, Exp d, Exp e) = Exp (a, b, c, d, e)
  tuple   t = tup5 t
  untuple x = untup5 x

instance (Elem a, Elem b, Elem c, Elem d, Elem e, Elem f)
      => Tuple (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f) where
  type TupleT (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f)
    = Exp (a, b, c, d, e, f)
  tuple   t = tup6 t
  untuple x = untup6 x

instance (Elem a, Elem b, Elem c, Elem d, Elem e, Elem f, Elem g)
      => Tuple (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g) where
  type TupleT (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g)
    = Exp (a, b, c, d, e, f, g)
  tuple   t = tup7 t
  untuple x = untup7 x

instance (Elem a, Elem b, Elem c, Elem d, Elem e, Elem f, Elem g, Elem h)
      => Tuple (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h) where
  type TupleT (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h)
    = Exp (a, b, c, d, e, f, g, h)
  tuple   t = tup8 t
  untuple x = untup8 x

instance (Elem a, Elem b, Elem c, Elem d, Elem e, Elem f, Elem g, Elem h, Elem i)
      => Tuple (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i) where
  type TupleT (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i)
    = Exp (a, b, c, d, e, f, g, h, i)
  tuple   t = tup9 t
  untuple x = untup9 x


-- |Project the first component out of a pair-valued scalar expression.
--
fst :: forall a b. (Elem a, Elem b) => Exp (a, b) -> Exp a
fst pair = firstE
  where
    -- the annotation on the discarded half selects the 'Tuple' instance
    (firstE, _ :: Exp b) = untuple pair

-- |Project the second component out of a pair-valued scalar expression.
--
snd :: forall a b. (Elem a, Elem b) => Exp (a, b) -> Exp b
snd pair = secondE
  where
    -- the annotation on the discarded half selects the 'Tuple' instance
    (_ :: Exp a, secondE) = untuple pair

-- |Turn a function on pair-valued expressions into a function of two
-- separate expressions.
--
curry :: (Elem a, Elem b) => (Exp (a,b) -> Exp c) -> Exp a -> Exp b -> Exp c
curry f = \x y -> f (tuple (x, y))

-- |Turn a function of two expressions into a function on a single
-- pair-valued expression.
--
uncurry :: (Elem a, Elem b) => (Exp a -> Exp b -> Exp c) -> Exp (a,b) -> Exp c
uncurry f t = f x y
  where
    (x, y) = untuple t


-- Conditional expressions
-- -----------------------

infix 0 ?

-- |Conditional expression: @c ? (t, e)@ selects @t@ when @c@ holds and @e@
-- otherwise.  The fixity is the loosest possible, so the condition needs no
-- parentheses.
--
(?) :: Elem t => Exp Bool -> (Exp t, Exp t) -> Exp t
cond ? (thenE, elseE) = Cond cond thenE elseE


-- Array operations with a scalar result
-- -------------------------------------

infixl 9 !

-- |Scalar expression that reads the element of an array at the given index.
--
(!) :: (Ix dim, Elem e) => Acc (Array dim e) -> Exp dim -> Exp e
arr ! ix = IndexScalar arr ix

-- |Scalar expression that yields the shape of an array.
--
shape :: (Ix dim, Elem dim) => Acc (Array dim e) -> Exp dim
shape arr = Shape arr


-- Instances of all relevant H98 classes
-- -------------------------------------

-- |Bounds of scalar expressions, built with 'mkMinBound' and 'mkMaxBound'.
instance (Elem t, IsBounded t) => Bounded (Exp t) where
  minBound = mkMinBound
  maxBound = mkMaxBound

-- |Method-less instance; the 'Integral' instance in this module needs 'Enum'
-- as a superclass.  Missing-method warnings are disabled for this module, and
-- every 'Enum' method (including the class defaults, which go through
-- 'toEnum'/'fromEnum') fails at runtime when applied to an 'Exp'.
instance (Elem t, IsScalar t) => Enum (Exp t)
--  succ = mkSucc
--  pred = mkPred
  -- FIXME: ops

-- |Exists only to satisfy superclass requirements (presumably of the numeric
-- classes below).  '==' would have to return a host 'Bool', which an
-- expression cannot provide, so it raises an error; use '==*' instead.  The
-- default '/=' goes through '==' and therefore errors as well.
instance (Elem t, IsScalar t) => Prelude.Eq (Exp t) where
  -- FIXME: instance makes no sense with standard signatures
  (==)        = error "Prelude.Eq.== applied to EDSL types"

-- |Needed as a superclass (e.g. of 'Real').  'compare' raises an error, and
-- so do the class defaults built on it; use '<*', '<=*', '>*', '>=*', 'max'
-- and 'min' from this module instead.
instance (Elem t, IsScalar t) => Prelude.Ord (Exp t) where
  -- FIXME: instance makes no sense with standard signatures
  compare     = error "Prelude.Ord.compare applied to EDSL types"

-- |Bitwise operations on integral scalar expressions.  Only the operations
-- whose 'Bits' signatures fit 'Exp' are defined here.  The others fix
-- argument or result types ('Int', 'Bool') in their class signatures, so this
-- module provides them as stand-alone functions ('shift', 'rotate', 'bit',
-- 'setBit', 'testBit', ...) instead.
--
instance (Elem t, IsNum t, IsIntegral t) => Bits (Exp t) where
  x .&. y      = mkBAnd x y
  x .|. y      = mkBOr  x y
  xor x y      = mkBXor x y
  complement x = mkBNot x

-- |Shift @x@ by @i@ bit positions: left for positive @i@, right by @-i@ for
-- negative @i@, and unchanged when @i@ is zero.
--
shift :: (Elem t, IsIntegral t) => Exp t -> Exp Int -> Exp t
shift x i = i ==* 0 ? (x, shifted)
  where
    shifted = i <* 0 ? (x `shiftR` negate i, x `shiftL` i)

-- |Shift @x@ left by @i@ bit positions.
shiftL :: (Elem t, IsIntegral t) => Exp t -> Exp Int -> Exp t
shiftL x i = mkBShiftL x i

-- |Shift @x@ right by @i@ bit positions.
shiftR :: (Elem t, IsIntegral t) => Exp t -> Exp Int -> Exp t
shiftR x i = mkBShiftR x i

-- |Rotate @x@ by @i@ bit positions: left for positive @i@, right by @-i@ for
-- negative @i@, and unchanged when @i@ is zero.
--
rotate :: (Elem t, IsIntegral t) => Exp t -> Exp Int -> Exp t
rotate x i = i ==* 0 ? (x, rotated)
  where
    rotated = i <* 0 ? (x `rotateR` negate i, x `rotateL` i)

-- |Rotate @x@ left by @i@ bit positions.
rotateL :: (Elem t, IsIntegral t) => Exp t -> Exp Int -> Exp t
rotateL x i = mkBRotateL x i

-- |Rotate @x@ right by @i@ bit positions.
rotateR :: (Elem t, IsIntegral t) => Exp t -> Exp Int -> Exp t
rotateR x i = mkBRotateR x i

-- |The value with only bit @i@ set, i.e. @1@ shifted left by @i@.
bit :: (Elem t, IsIntegral t) => Exp Int -> Exp t
bit i = shiftL 1 i

-- |Set, clear or flip bit @i@ of @x@, respectively.
setBit, clearBit, complementBit :: (Elem t, IsIntegral t) => Exp t -> Exp Int -> Exp t
setBit        x i = x .|. bit i
clearBit      x i = x .&. complement (bit i)
complementBit x i = xor x (bit i)

-- |Boolean expression that holds when bit @i@ of @x@ is set.
testBit :: (Elem t, IsIntegral t) => Exp t -> Exp Int -> Exp Bool
testBit x i = masked /=* 0
  where
    masked = x .&. bit i


-- |Arithmetic on numeric scalar expressions.  Each operation builds an
-- expression node.  Integer literals become embedded constants via
-- 'constant'.
--
instance (Elem t, IsNum t) => Num (Exp t) where
  x + y         = mkAdd x y
  x - y         = mkSub x y
  x * y         = mkMul x y
  negate x      = mkNeg x
  abs x         = mkAbs x
  signum x      = mkSig x
  fromInteger n = constant (fromInteger n)

-- |Method-less instance; the 'Integral' instance in this module needs 'Real'
-- as a superclass.  'toRational' is not defined, so calling it on an 'Exp'
-- fails at runtime (missing-method warnings are disabled for this module).
instance (Elem t, IsNum t) => Real (Exp t)
  -- FIXME: Why did we include this class?  We won't need `toRational' until
  --   we support rational numbers in AP computations.

-- |Integral division on scalar expressions.  'quot', 'rem', 'div' and 'mod'
-- map directly onto expression builders.
--
-- 'quotRem' has no class default.  The class default for 'divMod' calls
-- 'quotRem' and '==', and '==' raises an error for 'Exp'.  So both are
-- defined explicitly by pairing the single-result operations.  Without these
-- definitions, both fail at runtime, because missing-method warnings are
-- disabled for this module.
--
-- 'toInteger' is left undefined: a host 'Integer' cannot be obtained from an
-- unevaluated embedded expression.
--
instance (Elem t, IsIntegral t) => Integral (Exp t) where
  quot = mkQuot
  rem  = mkRem
  div  = mkIDiv
  mod  = mkMod
  quotRem x y = (quot x y, rem x y)
  divMod  x y = (div  x y, mod x y)

-- |Floating-point operations on scalar expressions.  Each method builds the
-- corresponding expression node via its @mk*@ builder.
--
-- 'sinh' and 'cosh' are not bound to a builder here and have no class
-- default.  Without a definition they, and the class default for 'tanh'
-- (@sinh x / cosh x@), would fail at runtime with a missing-method error,
-- since the warning is disabled for this module.  They are therefore
-- expressed through 'exp'.
--
-- NOTE(review): compared with a dedicated primitive, the 'exp' form loses
-- relative precision for 'sinh' near zero and overflows slightly earlier.
-- The argument expression also appears twice in the resulting tree.
--
instance (Elem t, IsFloating t) => Floating (Exp t) where
  pi      = mkPi
  sin     = mkSin
  cos     = mkCos
  tan     = mkTan
  asin    = mkAsin
  acos    = mkAcos
  atan    = mkAtan
  asinh   = mkAsinh
  acosh   = mkAcosh
  atanh   = mkAtanh
  exp     = mkExpFloating
  sqrt    = mkSqrt
  log     = mkLog
  (**)    = mkFPow
  logBase = mkLogBase
  sinh x  = (exp x - exp (negate x)) / 2
  cosh x  = (exp x + exp (negate x)) / 2
  -- FIXME: add other ops (e.g. dedicated sinh/cosh/tanh builders)

-- |Division on floating-point scalar expressions.  Rational literals become
-- embedded constants via 'constant'.
--
instance (Elem t, IsFloating t) => Fractional (Exp t) where
  x / y          = mkFDiv x y
  recip x        = mkRecip x
  fromRational r = constant (fromRational r)
  -- FIXME: add other ops

-- |Method-less instance, required because 'RealFloat' has 'RealFrac' as a
-- superclass.  'properFraction' is not defined, so it and the class defaults
-- built on it ('truncate', 'round', 'ceiling', 'floor') fail at runtime on an
-- 'Exp'.  Missing-method warnings are disabled for this module.  Use the
-- explicit conversions such as 'roundFloatToInt' and 'truncateFloatToInt'
-- instead.
instance (Elem t, IsFloating t) => RealFrac (Exp t)
  -- FIXME: add ops

-- |Only 'atan2' is provided, built by 'mkAtan2'.  The remaining 'RealFloat'
-- methods are not defined and fail at runtime if used on an 'Exp'.
instance (Elem t, IsFloating t) => RealFloat (Exp t) where
  atan2 = mkAtan2
  -- FIXME: add ops


-- Methods from H98 classes, where we need other signatures
-- --------------------------------------------------------

infix 4 ==*, /=*, <*, <=*, >*, >=*

-- |Equality of two scalar expressions, as a Boolean expression.
--
(==*) :: (Elem t, IsScalar t) => Exp t -> Exp t -> Exp Bool
x ==* y = mkEq x y

-- |Inequality of two scalar expressions, as a Boolean expression.
--
(/=*) :: (Elem t, IsScalar t) => Exp t -> Exp t -> Exp Bool
x /=* y = mkNEq x y

-- There is no lifted 'compare' yet: it would need to return an 'Ordering',
-- and enumerations are not supported at the moment.

-- |Less-than on scalar expressions, as a Boolean expression.
--
(<*) :: (Elem t, IsScalar t) => Exp t -> Exp t -> Exp Bool
x <* y = mkLt x y

-- |Less-than-or-equal on scalar expressions, as a Boolean expression.
--
(<=*) :: (Elem t, IsScalar t) => Exp t -> Exp t -> Exp Bool
x <=* y = mkLtEq x y

-- |Greater-than on scalar expressions, as a Boolean expression.
--
(>*) :: (Elem t, IsScalar t) => Exp t -> Exp t -> Exp Bool
x >* y = mkGt x y

-- |Greater-than-or-equal on scalar expressions, as a Boolean expression.
--
(>=*) :: (Elem t, IsScalar t) => Exp t -> Exp t -> Exp Bool
x >=* y = mkGtEq x y

-- |The larger of two scalar expressions.
--
max :: (Elem t, IsScalar t) => Exp t -> Exp t -> Exp t
max x y = mkMax x y

-- |The smaller of two scalar expressions.
--
min :: (Elem t, IsScalar t) => Exp t -> Exp t -> Exp t
min x y = mkMin x y


-- Non-overloaded standard functions, where we need other signatures
-- -----------------------------------------------------------------

infixr 3 &&*
infixr 2 ||*

-- |Logical conjunction of two Boolean expressions.
--
(&&*) :: Exp Bool -> Exp Bool -> Exp Bool
x &&* y = mkLAnd x y

-- |Logical disjunction of two Boolean expressions.
--
(||*) :: Exp Bool -> Exp Bool -> Exp Bool
x ||* y = mkLOr x y

-- |Logical negation of a Boolean expression.
--
not :: Exp Bool -> Exp Bool
not x = mkLNot x


-- Conversions
-- -----------

-- |Boolean expression as an 'Int' expression: 'False' becomes @0@ and
-- 'True' becomes @1@.
--
boolToInt :: Exp Bool -> Exp Int
boolToInt b = mkBoolToInt b

-- |Convert an 'Int' expression to a 'Float' expression.
intToFloat :: Exp Int -> Exp Float
intToFloat n = mkIntFloat n

-- |Round a 'Float' expression to an 'Int' expression.
roundFloatToInt :: Exp Float -> Exp Int
roundFloatToInt x = mkRoundFloatInt x

-- |Truncate a 'Float' expression to an 'Int' expression.
truncateFloatToInt :: Exp Float -> Exp Int
truncateFloatToInt x = mkTruncFloatInt x


-- Constants
-- ---------

-- |Magic index value for forward permutations: any element that the
-- 'permute' index mapping sends to 'ignore' is dropped from the result.  It
-- is the host-side 'Sugar.ignore' embedded as an expression via 'constant'.
--
ignore :: Ix dim => Exp dim
ignore = constant Sugar.ignore