{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TemplateHaskell       #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.AST
-- Copyright   : [2008..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- /Scalar versus collective operations/
--
-- The embedded array processing language is a two-level language.  It
-- combines a language of scalar expressions and functions with a language of
-- collective array operations.  Scalar expressions are used to compute
-- arguments for collective operations and scalar functions are used to
-- parametrise higher-order, collective array operations.  The two-level
-- structure, in particular, ensures that collective operations cannot be
-- parametrised with collective operations; hence, we are following a flat
-- data-parallel model.  The collective operations manipulate
-- multi-dimensional arrays whose shape is explicitly tracked in their types.
-- In fact, collective operations cannot produce any values other than
-- multi-dimensional arrays; when they yield a scalar, this is in the form of
-- a 0-dimensional, singleton array.  Similarly, scalar expressions can---as
-- their name indicates---only produce tuples of scalars, but not arrays.
--
-- There are, however, two expression forms that take arrays as arguments.  As
-- a result scalar and array expressions are recursively dependent.  As we
-- cannot and don't want to compute arrays in the middle of scalar
-- computations, array computations will always be hoisted out of scalar
-- expressions.  So that this is always possible, these array expressions may
-- not contain any free scalar variables.  To express that condition in the
-- type structure, we use separate environments for scalar and array variables.
--
-- /Programs/
--
-- Collective array programs comprise closed expressions of array operations.
-- There is no explicit sharing in the initial AST form, but sharing is
-- introduced subsequently by common subexpression elimination and floating
-- of array computations.
--
-- /Functions/
--
-- The array expression language is first-order and only provides limited
-- control structures to ensure that it can be efficiently executed on
-- compute-acceleration hardware, such as GPUs.  To restrict functions to
-- first-order, we separate function abstraction from the main expression
-- type.  Functions are represented using de Bruijn indices.
--
-- /Parametric and ad-hoc polymorphism/
--
-- The array language features parametric polymorphism (e.g., pairing and
-- projections) as well as ad-hoc polymorphism (e.g., arithmetic operations).
-- All ad-hoc polymorphic constructs include reified dictionaries (cf.
-- module 'Types').  Reified dictionaries also ensure that constants
-- (constructor 'Const') are representable on compute acceleration hardware.
--
-- The AST contains both reified dictionaries and type class constraints.
-- Type classes are used for array-related functionality that is uniformly
-- available for all supported types.  In contrast, reified dictionaries are
-- used for functionality that is only available for certain types, such as
-- arithmetic operations.
--

module Data.Array.Accelerate.AST (

  -- * Internal AST
  -- ** Array computations
  Afun, PreAfun, OpenAfun, PreOpenAfun(..),
  Acc, OpenAcc(..), PreOpenAcc(..), Direction(..),
  ALeftHandSide, ArrayVar, ArrayVars,

  -- ** Scalar expressions
  ELeftHandSide, ExpVar, ExpVars, expVars,
  Fun, OpenFun(..),
  Exp, OpenExp(..),
  Boundary(..),
  PrimConst(..),
  PrimFun(..),
  PrimBool,
  PrimMaybe,

  -- ** Extracting type information
  HasArraysR(..), arrayR,
  expType,
  primConstType,
  primFunType,

  -- ** Normal-form
  NFDataAcc,
  rnfOpenAfun, rnfPreOpenAfun,
  rnfOpenAcc, rnfPreOpenAcc,
  rnfALeftHandSide,
  rnfArrayVar,
  rnfOpenFun,
  rnfOpenExp,
  rnfELeftHandSide,
  rnfExpVar,
  rnfBoundary,
  rnfConst,
  rnfPrimConst,
  rnfPrimFun,

  -- ** Template Haskell
  LiftAcc,
  liftPreOpenAfun,
  liftPreOpenAcc,
  liftALeftHandSide,
  liftArrayVar,
  liftOpenFun,
  liftOpenExp,
  liftELeftHandSide,
  liftExpVar,
  liftBoundary,
  liftPrimConst,
  liftPrimFun,

  -- ** Miscellaneous
  showPreAccOp,
  showExpOp,

) where

import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Slice
import Data.Array.Accelerate.Representation.Stencil
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Vec
import Data.Array.Accelerate.Sugar.Foreign
import Data.Array.Accelerate.Type
import Data.Primitive.Vec

import Control.DeepSeq
import Data.Kind
import Language.Haskell.TH                                          ( Q, TExp )
import Prelude

import GHC.TypeLits


-- Array expressions
-- -----------------

-- | Function abstraction over parametrised array computations.
--
-- The type is indexed by the recursive closure @acc@ used for the function
-- body, the array environment @aenv@ and the function type @t@.
--
data PreOpenAfun acc aenv t where
  -- Function body: a computation @acc@ of type @t@ in environment @aenv@
  Abody ::                               acc             aenv  t -> PreOpenAfun acc aenv t
  -- Abstraction over an argument of type @a@. The left-hand side relates the
  -- outer environment @aenv@ to the environment @aenv'@ in which the
  -- remainder of the function is typed.
  Alam  :: ALeftHandSide a aenv aenv' -> PreOpenAfun acc aenv' t -> PreOpenAfun acc aenv (a -> t)

-- | Function abstraction over vanilla open array computations, i.e.
-- 'PreOpenAfun' with its @acc@ parameter instantiated to 'OpenAcc'.
--
type OpenAfun = PreOpenAfun OpenAcc

-- | Parametrised array-computation function without free array variables:
-- the array environment is fixed to the empty environment @()@.
--
type PreAfun acc = PreOpenAfun acc ()

-- | Vanilla array-computation function without free array variables:
-- an 'OpenAfun' whose array environment is the empty environment @()@.
--
type Afun = OpenAfun ()

-- | Vanilla open array computations. This newtype ties the recursive knot
-- of 'PreOpenAcc' by instantiating its @acc@ parameter with 'OpenAcc'
-- itself.
--
newtype OpenAcc aenv t = OpenAcc (PreOpenAcc OpenAcc aenv t)

-- | Closed array expression aka an array program: an 'OpenAcc' whose array
-- environment is the empty environment @()@.
--
type Acc = OpenAcc ()

-- Types for array binders: left-hand sides, variables and tuples of
-- variables specialised to the array representation witness 'ArrayR'.
--
type ALeftHandSide  = LeftHandSide ArrayR
type ArrayVar       = Var ArrayR
type ArrayVars aenv = Vars ArrayR aenv

-- Bool is not a primitive type: at this level a boolean is a 'TAG', and an
-- optional value is a 'TAG' paired with a nested tuple holding the payload.
-- NOTE(review): presumably the tag identifies the constructor; which tag
-- value denotes which constructor is not visible here --- confirm against
-- the module defining 'TAG'.
type PrimBool    = TAG
type PrimMaybe a = (TAG, ((), a))


-- | Collective array computations parametrised over array variables
-- represented with de Bruijn indices.
--
-- * Scalar functions and expressions embedded in well-formed array
--   computations cannot contain free scalar variable indices. The latter
--   cannot be bound in array computations, and hence, cannot appear in any
--   well-formed program.
--
-- * The let-form is used to represent the sharing discovered by common
--   subexpression elimination as well as to control evaluation order. (We
--   need to hoist array expressions out of scalar expressions---they occur
--   in scalar indexing and in determining an array's shape.)
--
-- The data type is parameterised over the surface types (not the
-- representation type).
--
-- NOTE(review): the constructors below carry representation witnesses
-- ('ArraysR', 'ArrayR', 'TypeR', 'ShapeR'), so the statement above may be
-- out of date --- confirm.
--
-- We use a non-recursive variant parametrised over the recursive closure,
-- to facilitate attribute calculation in the backend.
--
data PreOpenAcc (acc :: Type -> Type -> Type) aenv a where

  -- Local non-recursive binding to represent sharing and demand
  -- explicitly. Note this is an eager binding!
  --
  Alet        :: ALeftHandSide bndArrs aenv aenv'
              -> acc            aenv  bndArrs         -- bound expression
              -> acc            aenv' bodyArrs        -- the bound expression scope
              -> PreOpenAcc acc aenv  bodyArrs

  -- Variable bound by an 'Alet', represented by a de Bruijn index
  --
  Avar        :: ArrayVar       aenv (Array sh e)
              -> PreOpenAcc acc aenv (Array sh e)

  -- Tuples of arrays: 'Apair' combines two array computations into a pair,
  -- 'Anil' is the empty tuple
  --
  Apair       :: acc            aenv as
              -> acc            aenv bs
              -> PreOpenAcc acc aenv (as, bs)

  Anil        :: PreOpenAcc acc aenv ()

  -- Array-function application.
  --
  -- The array function is not closed at the core level because we need access
  -- to free variables introduced by 'run1' style evaluators. See Issue#95.
  --
  Apply       :: ArraysR arrs2
              -> PreOpenAfun acc aenv (arrs1 -> arrs2)
              -> acc             aenv arrs1
              -> PreOpenAcc  acc aenv arrs2

  -- Apply a backend-specific foreign function to an array, with a pure
  -- Accelerate version for use with other backends. The functions must be
  -- closed.
  --
  Aforeign    :: Foreign asm
              => ArraysR bs
              -> asm                   (as -> bs) -- The foreign function for a given backend
              -> PreAfun      acc      (as -> bs) -- Fallback implementation(s)
              -> acc              aenv as         -- Arguments to the function
              -> PreOpenAcc   acc aenv bs

  -- If-then-else for array-level computations
  --
  Acond       :: Exp            aenv PrimBool
              -> acc            aenv arrs
              -> acc            aenv arrs
              -> PreOpenAcc acc aenv arrs

  -- Value-recursion for array-level computations
  --
  Awhile      :: PreOpenAfun acc aenv (arrs -> Scalar PrimBool) -- continue iteration while true
              -> PreOpenAfun acc aenv (arrs -> arrs)            -- function to iterate
              -> acc             aenv arrs                      -- initial value
              -> PreOpenAcc  acc aenv arrs


  -- Array inlet. Triggers (possibly) asynchronous host->device transfer if
  -- necessary.
  --
  Use         :: ArrayR (Array sh e)
              -> Array sh e
              -> PreOpenAcc acc aenv (Array sh e)

  -- Capture a scalar (or a tuple of scalars) in a singleton array
  --
  Unit        :: TypeR e
              -> Exp            aenv e
              -> PreOpenAcc acc aenv (Scalar e)

  -- Change the shape of an array without altering its contents.
  -- Precondition (this may not be checked!):
  --
  -- > dim == size dim'
  --
  -- NOTE(review): presumably intended as @size dim == size dim'@, i.e. the
  -- old and new shapes contain the same number of elements --- confirm.
  --
  Reshape     :: ShapeR sh
              -> Exp            aenv sh                         -- new shape
              -> acc            aenv (Array sh' e)              -- array to be reshaped
              -> PreOpenAcc acc aenv (Array sh e)

  -- Construct a new array by applying a function to each index.
  --
  Generate    :: ArrayR (Array sh e)
              -> Exp            aenv sh                         -- output shape
              -> Fun            aenv (sh -> e)                  -- representation function
              -> PreOpenAcc acc aenv (Array sh e)

  -- Hybrid map/backpermute, where we separate the index and value
  -- transformations.
  --
  Transform   :: ArrayR (Array sh' b)
              -> Exp            aenv sh'                        -- dimension of the result
              -> Fun            aenv (sh' -> sh)                -- index permutation function
              -> Fun            aenv (a   -> b)                 -- function to apply at each element
              ->            acc aenv (Array sh  a)              -- source array
              -> PreOpenAcc acc aenv (Array sh' b)

  -- Replicate an array across one or more dimensions as given by the first
  -- argument
  --
  Replicate   :: SliceIndex slix sl co sh                       -- slice type specification
              -> Exp            aenv slix                       -- slice value specification
              -> acc            aenv (Array sl e)               -- data to be replicated
              -> PreOpenAcc acc aenv (Array sh e)

  -- Index a sub-array out of an array; i.e., the dimensions not indexed
  -- are returned whole
  --
  Slice       :: SliceIndex slix sl co sh                       -- slice type specification
              -> acc            aenv (Array sh e)               -- array to be indexed
              -> Exp            aenv slix                       -- slice value specification
              -> PreOpenAcc acc aenv (Array sl e)

  -- Apply the given unary function to all elements of the given array
  --
  Map         :: TypeR e'
              -> Fun            aenv (e -> e')
              -> acc            aenv (Array sh e)
              -> PreOpenAcc acc aenv (Array sh e')

  -- Apply a given binary function pairwise to all elements of the given
  -- arrays. The length of the result is the length of the shorter of the
  -- two argument arrays.
  --
  ZipWith     :: TypeR e3
              -> Fun            aenv (e1 -> e2 -> e3)
              -> acc            aenv (Array sh e1)
              -> acc            aenv (Array sh e2)
              -> PreOpenAcc acc aenv (Array sh e3)

  -- Fold along the innermost dimension of an array with a given
  -- /associative/ function.
  --
  Fold        :: Fun            aenv (e -> e -> e)              -- combination function
              -> Maybe     (Exp aenv e)                         -- default value
              -> acc            aenv (Array (sh, Int) e)        -- folded array
              -> PreOpenAcc acc aenv (Array sh e)

  -- Segmented fold along the innermost dimension of an array with a given
  -- /associative/ function
  --
  FoldSeg     :: IntegralType i
              -> Fun            aenv (e -> e -> e)              -- combination function
              -> Maybe     (Exp aenv e)                         -- default value
              -> acc            aenv (Array (sh, Int) e)        -- folded array
              -> acc            aenv (Segments i)               -- segment descriptor
              -> PreOpenAcc acc aenv (Array (sh, Int) e)

  -- Haskell-style scan of a linear array with a given
  -- /associative/ function and optionally an initial element
  -- (which does not need to be the neutral of the associative operations)
  -- If no initial value is given, this is a scan1
  --
  Scan        :: Direction
              -> Fun            aenv (e -> e -> e)              -- combination function
              -> Maybe     (Exp aenv e)                         -- initial value
              -> acc            aenv (Array (sh, Int) e)
              -> PreOpenAcc acc aenv (Array (sh, Int) e)

  -- Like 'Scan', but produces a rightmost (in case of a left-to-right scan)
  -- fold value and an array with the same length as the input array (the
  -- fold value would be the rightmost element in a Haskell-style scan)
  --
  Scan'       :: Direction
              -> Fun            aenv (e -> e -> e)              -- combination function
              -> Exp            aenv e                          -- initial value
              -> acc            aenv (Array (sh, Int) e)
              -> PreOpenAcc acc aenv (Array (sh, Int) e, Array sh e)

  -- Generalised forward permutation is characterised by a permutation function
  -- that determines for each element of the source array where it should go in
  -- the output. The permutation can be between arrays of varying shape and
  -- dimensionality.
  --
  -- Other characteristics of the permutation function 'f':
  --
  --   1. 'f' is a partial function: if it evaluates to the magic value 'ignore'
  --      (i.e. a tuple of -1 values) then those elements of the domain are
  --      dropped.
  --
  --   2. 'f' is not surjective: positions in the target array need not be
  --      picked up by the permutation function, so the target array must first
  --      be initialised from an array of default values.
  --
  --   3. 'f' is not injective: distinct elements of the domain may map to the
  --      same position in the target array. In this case the combination
  --      function is used to combine elements, which needs to be /associative/
  --      and /commutative/.
  --
  Permute     :: Fun            aenv (e -> e -> e)              -- combination function
              -> acc            aenv (Array sh' e)              -- default values
              -> Fun            aenv (sh -> PrimMaybe sh')      -- permutation function
              -> acc            aenv (Array sh e)               -- source array
              -> PreOpenAcc acc aenv (Array sh' e)

  -- Generalised multi-dimensional backwards permutation; the permutation can
  -- be between arrays of varying shape; the permutation function must be total
  --
  Backpermute :: ShapeR sh'
              -> Exp            aenv sh'                        -- dimensions of the result
              -> Fun            aenv (sh' -> sh)                -- permutation function
              -> acc            aenv (Array sh e)               -- source array
              -> PreOpenAcc acc aenv (Array sh' e)

  -- Map a stencil over an array.  In contrast to 'map', the domain of
  -- a stencil function is an entire /neighbourhood/ of each array element.
  --
  Stencil     :: StencilR sh e stencil
              -> TypeR e'
              -> Fun             aenv (stencil -> e')           -- stencil function
              -> Boundary        aenv (Array sh e)              -- boundary condition
              -> acc             aenv (Array sh e)              -- source array
              -> PreOpenAcc  acc aenv (Array sh e')

  -- Map a binary stencil over an array.
  --
  Stencil2    :: StencilR sh a stencil1
              -> StencilR sh b stencil2
              -> TypeR c
              -> Fun             aenv (stencil1 -> stencil2 -> c) -- stencil function
              -> Boundary        aenv (Array sh a)                -- boundary condition #1
              -> acc             aenv (Array sh a)                -- source array #1
              -> Boundary        aenv (Array sh b)                -- boundary condition #2
              -> acc             aenv (Array sh b)                -- source array #2
              -> PreOpenAcc acc  aenv (Array sh c)


-- | Traversal direction, taken as the first argument of the 'Scan' and
-- 'Scan'' array operations.
--
data Direction = LeftToRight | RightToLeft
  deriving Eq


-- | Vanilla boundary condition specification for stencil operations.
--
-- The index @t@ is the type of the array the boundary condition applies to
-- (see the 'Stencil' and 'Stencil2' constructors of 'PreOpenAcc'). 'Clamp',
-- 'Mirror' and 'Wrap' are valid at any @t@; 'Constant' and 'Function' fix
-- it to @Array sh e@ because they mention the element type.
--
data Boundary aenv t where
  -- Clamp coordinates to the extent of the array
  Clamp     :: Boundary aenv t

  -- Mirror coordinates beyond the array extent
  Mirror    :: Boundary aenv t

  -- Wrap coordinates around on each dimension
  Wrap      :: Boundary aenv t

  -- Use a constant value (of the element type @e@) for outlying coordinates
  Constant  :: e
            -> Boundary aenv (Array sh e)

  -- Apply the given function, from an index of type @sh@ to an element, to
  -- outlying coordinates
  Function  :: Fun aenv (sh -> e)
            -> Boundary aenv (Array sh e)


-- Embedded expressions
-- --------------------

-- | Vanilla open function abstraction.
--
-- Indexed by the scalar environment @env@, the array environment @aenv@ and
-- the function type @t@.
--
data OpenFun env aenv t where
  -- Function body: a scalar expression of type @t@
  Body ::                             OpenExp env  aenv t -> OpenFun env aenv t
  -- Abstraction over an argument of type @a@. The left-hand side relates the
  -- outer scalar environment @env@ to the environment @env'@ in which the
  -- remainder of the function is typed; the array environment is unchanged.
  Lam  :: ELeftHandSide a env env' -> OpenFun env' aenv t -> OpenFun env aenv (a -> t)

-- | Vanilla function without free scalar variables: the scalar environment
-- is fixed to @()@, while the array environment remains a parameter.
--
type Fun = OpenFun ()

-- | Vanilla expression without free scalar variables: the scalar
-- environment is fixed to @()@, while the array environment remains a
-- parameter.
--
type Exp = OpenExp ()

-- Types for scalar bindings: left-hand sides, variables and tuples of
-- variables specialised to the scalar type witness 'ScalarType'.
--
type ELeftHandSide = LeftHandSide ScalarType
type ExpVar        = Var ScalarType
type ExpVars env   = Vars ScalarType env

-- | Turn a tuple of scalar variables into the expression referring to them.
--
-- The structure of the tuple is mirrored one-to-one: the unit tuple becomes
-- 'Nil', a single variable becomes an 'Evar', and a pair of variable tuples
-- becomes a 'Pair' of the converted components.
--
expVars :: ExpVars env t -> OpenExp env aenv t
expVars = \case
  TupRunit           -> Nil
  TupRsingle v       -> Evar v
  TupRpair lefts rights ->
    Pair (expVars lefts) (expVars rights)


-- | Vanilla open expressions using de Bruijn indices for variables ranging
-- over tuples of scalars and arrays of tuples. All code, except Cond, is
-- evaluated eagerly. N-tuples are represented as nested pairs.
--
-- The data type is parametrised over the representation type (not the
-- surface types).
--
-- The indices are the scalar environment @env@, the array environment
-- @aenv@ and the result type @t@.
--
data OpenExp env aenv t where

  -- Local binding of a scalar expression
  Let           :: ELeftHandSide bnd_t env env'
                -> OpenExp env  aenv bnd_t
                -> OpenExp env' aenv body_t
                -> OpenExp env  aenv body_t

  -- Variable index, ranging only over tuples or scalars
  Evar          :: ExpVar env t
                -> OpenExp env aenv t

  -- Apply a backend-specific foreign function
  Foreign       :: Foreign asm
                => TypeR y
                -> asm    (x -> y)    -- foreign function
                -> Fun () (x -> y)    -- alternate implementation (for other backends)
                -> OpenExp env aenv x
                -> OpenExp env aenv y

  -- Tuples
  Pair          :: OpenExp env aenv t1
                -> OpenExp env aenv t2
                -> OpenExp env aenv (t1, t2)

  Nil           :: OpenExp env aenv ()

  -- SIMD vectors: convert between a tuple representation @tup@ and the
  -- packed vector type @Vec n s@, as witnessed by the 'VecR' argument
  VecPack       :: KnownNat n
                => VecR n s tup
                -> OpenExp env aenv tup
                -> OpenExp env aenv (Vec n s)

  VecUnpack     :: KnownNat n
                => VecR n s tup
                -> OpenExp env aenv (Vec n s)
                -> OpenExp env aenv tup

  -- Array indices & shapes
  IndexSlice    :: SliceIndex slix sl co sh
                -> OpenExp env aenv slix
                -> OpenExp env aenv sh
                -> OpenExp env aenv sl

  IndexFull     :: SliceIndex slix sl co sh
                -> OpenExp env aenv slix
                -> OpenExp env aenv sl
                -> OpenExp env aenv sh

  -- Shape and index conversion
  ToIndex       :: ShapeR sh
                -> OpenExp env aenv sh           -- shape of the array
                -> OpenExp env aenv sh           -- index into the array
                -> OpenExp env aenv Int

  FromIndex     :: ShapeR sh
                -> OpenExp env aenv sh           -- shape of the array
                -> OpenExp env aenv Int          -- index into linear representation
                -> OpenExp env aenv sh

  -- Case statement: the first argument is the scrutinee (a 'TAG'), followed
  -- by the tagged alternatives and an optional default alternative
  Case          :: OpenExp env aenv TAG
                -> [(TAG, OpenExp env aenv b)]      -- list of equations
                -> Maybe (OpenExp env aenv b)       -- default case
                -> OpenExp env aenv b

  -- Conditional expression (non-strict in 2nd and 3rd argument)
  Cond          :: OpenExp env aenv PrimBool
                -> OpenExp env aenv t
                -> OpenExp env aenv t
                -> OpenExp env aenv t

  -- Value recursion
  While         :: OpenFun env aenv (a -> PrimBool) -- continue while true
                -> OpenFun env aenv (a -> a)        -- function to iterate
                -> OpenExp env aenv a               -- initial value
                -> OpenExp env aenv a

  -- Constant values
  Const         :: ScalarType t
                -> t
                -> OpenExp env aenv t

  PrimConst     :: PrimConst t
                -> OpenExp env aenv t

  -- Primitive scalar operations
  PrimApp       :: PrimFun (a -> r)
                -> OpenExp env aenv a
                -> OpenExp env aenv r

  -- Project a single scalar from an array.
  -- The array expression can not contain any free scalar variables.
  Index         :: ArrayVar    aenv (Array dim t)
                -> OpenExp env aenv dim
                -> OpenExp env aenv t

  -- Like 'Index', but the position is given as a single 'Int' rather than
  -- as an index of the array's shape type
  LinearIndex   :: ArrayVar    aenv (Array dim t)
                -> OpenExp env aenv Int
                -> OpenExp env aenv t

  -- Array shape.
  -- The array expression can not contain any free scalar variables.
  Shape         :: ArrayVar    aenv (Array dim e)
                -> OpenExp env aenv dim

  -- Number of elements of an array given its shape
  ShapeSize     :: ShapeR dim
                -> OpenExp env aenv dim
                -> OpenExp env aenv Int

  -- Unsafe operations (may fail or result in undefined behaviour)
  -- An unspecified bit pattern
  Undef         :: ScalarType t
                -> OpenExp env aenv t

  -- Reinterpret the bits of a value as a different type
  Coerce        :: BitSizeEq a b
                => ScalarType a
                -> ScalarType b
                -> OpenExp env aenv a
                -> OpenExp env aenv b

-- | Primitive constant values.
--
-- Each constructor carries the reified dictionary ('BoundedType' or
-- 'FloatingType') for the type @a@ at which the constant is taken.
--
data PrimConst ty where

  -- constants from Bounded
  PrimMinBound  :: BoundedType a -> PrimConst a
  PrimMaxBound  :: BoundedType a -> PrimConst a

  -- constant from Floating
  PrimPi        :: FloatingType a -> PrimConst a


-- | Primitive scalar operations.
--
-- The index @sig@ is the function signature of the operation. Every
-- primitive takes exactly one argument, so operations on two operands take
-- a pair. Where an operation is ad-hoc polymorphic, the constructor carries
-- the reified dictionary for the type(s) it is applied at.
--
data PrimFun sig where

  -- operators from Num
  PrimAdd  :: NumType a -> PrimFun ((a, a) -> a)
  PrimSub  :: NumType a -> PrimFun ((a, a) -> a)
  PrimMul  :: NumType a -> PrimFun ((a, a) -> a)
  PrimNeg  :: NumType a -> PrimFun (a      -> a)
  PrimAbs  :: NumType a -> PrimFun (a      -> a)
  PrimSig  :: NumType a -> PrimFun (a      -> a)

  -- operators from Integral
  PrimQuot     :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimRem      :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimQuotRem  :: IntegralType a -> PrimFun ((a, a)   -> (a, a))
  PrimIDiv     :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimMod      :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimDivMod   :: IntegralType a -> PrimFun ((a, a)   -> (a, a))

  -- operators from Bits & FiniteBits
  PrimBAnd               :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimBOr                :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimBXor               :: IntegralType a -> PrimFun ((a, a)   -> a)
  PrimBNot               :: IntegralType a -> PrimFun (a        -> a)
  PrimBShiftL            :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBShiftR            :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBRotateL           :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimBRotateR           :: IntegralType a -> PrimFun ((a, Int) -> a)
  PrimPopCount           :: IntegralType a -> PrimFun (a -> Int)
  PrimCountLeadingZeros  :: IntegralType a -> PrimFun (a -> Int)
  PrimCountTrailingZeros :: IntegralType a -> PrimFun (a -> Int)

  -- operators from Fractional and Floating
  PrimFDiv        :: FloatingType a -> PrimFun ((a, a) -> a)
  PrimRecip       :: FloatingType a -> PrimFun (a      -> a)
  PrimSin         :: FloatingType a -> PrimFun (a      -> a)
  PrimCos         :: FloatingType a -> PrimFun (a      -> a)
  PrimTan         :: FloatingType a -> PrimFun (a      -> a)
  PrimAsin        :: FloatingType a -> PrimFun (a      -> a)
  PrimAcos        :: FloatingType a -> PrimFun (a      -> a)
  PrimAtan        :: FloatingType a -> PrimFun (a      -> a)
  PrimSinh        :: FloatingType a -> PrimFun (a      -> a)
  PrimCosh        :: FloatingType a -> PrimFun (a      -> a)
  PrimTanh        :: FloatingType a -> PrimFun (a      -> a)
  PrimAsinh       :: FloatingType a -> PrimFun (a      -> a)
  PrimAcosh       :: FloatingType a -> PrimFun (a      -> a)
  PrimAtanh       :: FloatingType a -> PrimFun (a      -> a)
  PrimExpFloating :: FloatingType a -> PrimFun (a      -> a)
  PrimSqrt        :: FloatingType a -> PrimFun (a      -> a)
  PrimLog         :: FloatingType a -> PrimFun (a      -> a)
  PrimFPow        :: FloatingType a -> PrimFun ((a, a) -> a)
  PrimLogBase     :: FloatingType a -> PrimFun ((a, a) -> a)

  -- FIXME: add missing operations from RealFrac & RealFloat

  -- operators from RealFrac
  PrimTruncate :: FloatingType a -> IntegralType b -> PrimFun (a -> b)
  PrimRound    :: FloatingType a -> IntegralType b -> PrimFun (a -> b)
  PrimFloor    :: FloatingType a -> IntegralType b -> PrimFun (a -> b)
  PrimCeiling  :: FloatingType a -> IntegralType b -> PrimFun (a -> b)
  -- PrimProperFraction :: FloatingType a -> IntegralType b -> PrimFun (a -> (b, a))

  -- operators from RealFloat
  PrimAtan2      :: FloatingType a -> PrimFun ((a, a) -> a)
  PrimIsNaN      :: FloatingType a -> PrimFun (a -> PrimBool)
  PrimIsInfinite :: FloatingType a -> PrimFun (a -> PrimBool)

  -- relational and equality operators
  PrimLt   :: SingleType a -> PrimFun ((a, a) -> PrimBool)
  PrimGt   :: SingleType a -> PrimFun ((a, a) -> PrimBool)
  PrimLtEq :: SingleType a -> PrimFun ((a, a) -> PrimBool)
  PrimGtEq :: SingleType a -> PrimFun ((a, a) -> PrimBool)
  PrimEq   :: SingleType a -> PrimFun ((a, a) -> PrimBool)
  PrimNEq  :: SingleType a -> PrimFun ((a, a) -> PrimBool)
  PrimMax  :: SingleType a -> PrimFun ((a, a) -> a)
  PrimMin  :: SingleType a -> PrimFun ((a, a) -> a)

  -- logical operators
  --
  -- Note that these operators are strict in both arguments. That is, the
  -- second argument of PrimLAnd is always evaluated even when the first
  -- argument is false.
  --
  -- We define (surface level) (&&) and (||) using if-then-else to enable
  -- short-circuiting, while (&&!) and (||!) are strict versions of these
  -- operators, which are defined using PrimLAnd and PrimLOr.
  --
  PrimLAnd :: PrimFun ((PrimBool, PrimBool) -> PrimBool)
  PrimLOr  :: PrimFun ((PrimBool, PrimBool) -> PrimBool)
  PrimLNot :: PrimFun (PrimBool             -> PrimBool)

  -- general conversion between types
  PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b)
  PrimToFloating   :: NumType a -> FloatingType b -> PrimFun (a -> b)


-- Type utilities
-- --------------

-- | Terms, indexed by an array environment and a result type, from which the
-- representation of the result can be recovered.
--
class HasArraysR f where
  -- | The representation ('ArraysR') of the result of the given term.
  arraysR :: f aenv a -> ArraysR a

-- Strip the 'OpenAcc' wrapper and delegate to the instance of the wrapped term.
instance HasArraysR OpenAcc where
  arraysR (OpenAcc pacc) = arraysR pacc

-- | Representation of a term which yields a single array.
--
-- The representation of a single array can only be a 'TupRsingle', so the
-- projection below is the only case that needs to be handled.
arrayR :: HasArraysR f => f aenv (Array sh e) -> ArrayR (Array sh e)
arrayR = fromSingle . arraysR
  where
    fromSingle :: ArraysR (Array sh e) -> ArrayR (Array sh e)
    fromSingle (TupRsingle repr) = repr

-- | Recover the representation of the result of an array operation from the
-- structure of the term. The representation is read directly from the node
-- where it is stored there, and otherwise derived from the representation of
-- a sub-term.
--
instance HasArraysR acc => HasArraysR (PreOpenAcc acc) where
  arraysR (Alet _ _ body)             = arraysR body
  arraysR (Avar (Var aR _))           = TupRsingle aR
  arraysR (Apair as bs)               = TupRpair (arraysR as) (arraysR bs)
  arraysR Anil                        = TupRunit
  arraysR (Apply aR _ _)              = aR
  arraysR (Aforeign r _ _ _)          = r
  arraysR (Acond _ a _)               = arraysR a
  -- The loop state is bound by the left-hand side of the iteration function,
  -- so its representation can be read off without inspecting any sub-term.
  arraysR (Awhile _ (Alam lhs _) _)   = lhsToTupR lhs
  -- The iteration function is expected to always be an 'Alam'. Should it
  -- ever not be, use the initial value instead: it has the same type as the
  -- result of the loop, which keeps this function total rather than calling
  -- 'error'.
  arraysR (Awhile _ _ a)              = arraysR a
  arraysR (Use aR _)                  = TupRsingle aR
  arraysR (Unit tR _)                 = arraysRarray ShapeRz tR
  arraysR (Reshape sh _ a)            = let ArrayR _ tR = arrayR a
                                         in arraysRarray sh tR
  arraysR (Generate aR _ _)           = TupRsingle aR
  arraysR (Transform aR _ _ _ _)      = TupRsingle aR
  arraysR (Replicate slice _ a)       = let ArrayR _ tR = arrayR a
                                         in arraysRarray (sliceDomainR slice) tR
  arraysR (Slice slice a _)           = let ArrayR _ tR = arrayR a
                                         in arraysRarray (sliceShapeR slice) tR
  arraysR (Map tR _ a)                = let ArrayR sh _ = arrayR a
                                         in arraysRarray sh tR
  arraysR (ZipWith tR _ a _)          = let ArrayR sh _ = arrayR a
                                         in arraysRarray sh tR
  -- A fold drops the innermost dimension of its argument.
  arraysR (Fold _ _ a)                = let ArrayR (ShapeRsnoc sh) tR = arrayR a
                                         in arraysRarray sh tR
  arraysR (FoldSeg _ _ _ a _)         = arraysR a
  arraysR (Scan _ _ _ a)              = arraysR a
  -- Yields two arrays: one with the representation of the argument, and one
  -- with the innermost dimension dropped.
  arraysR (Scan' _ _ _ a)             = let aR@(ArrayR (ShapeRsnoc sh) tR) = arrayR a
                                         in TupRsingle aR `TupRpair` TupRsingle (ArrayR sh tR)
  arraysR (Permute _ a _ _)           = arraysR a
  arraysR (Backpermute sh _ _ a)      = let ArrayR _ tR = arrayR a
                                         in arraysRarray sh tR
  arraysR (Stencil _ tR _ _ a)        = let ArrayR sh _ = arrayR a
                                         in arraysRarray sh tR
  arraysR (Stencil2 _ _ tR _ _ a _ _) = let ArrayR sh _ = arrayR a
                                         in arraysRarray sh tR

-- | Determine the type of a scalar expression from the structure of the term.
--
-- The type is read directly from the node where it is stored there, and
-- otherwise derived from the type of a sub-term.
--
-- Raises an internal error for a 'Case' which has neither an alternative nor
-- a default branch, as there is no sub-term to take the type from.
--
expType :: HasCallStack => OpenExp env aenv t -> TypeR t
expType = \case
  Let _ _ body                 -> expType body
  Evar (Var tR _)              -> TupRsingle tR
  Foreign tR _ _ _             -> tR
  Pair e1 e2                   -> TupRpair (expType e1) (expType e2)
  Nil                          -> TupRunit
  VecPack   vecR _             -> TupRsingle $ VectorScalarType $ vecRvector vecR
  VecUnpack vecR _             -> vecRtuple vecR
  IndexSlice si _ _            -> shapeType $ sliceShapeR si
  IndexFull  si _ _            -> shapeType $ sliceDomainR si
  ToIndex{}                    -> TupRsingle scalarTypeInt
  FromIndex shr _ _            -> shapeType shr
  Case _ ((_,e):_) _           -> expType e
  Case _ [] (Just e)           -> expType e
  Case{}                       -> internalError "empty case encountered"
  Cond _ e _                   -> expType e
  -- The loop state is bound by the left-hand side of the iteration function,
  -- so its type can be read off without inspecting any sub-term.
  While _ (Lam lhs _) _        -> lhsToTupR lhs
  -- The iteration function is expected to always be a 'Lam'. Should it ever
  -- not be, use the initial value instead: it has the same type as the result
  -- of the loop, which keeps this branch total rather than calling 'error'.
  While _ _ x                  -> expType x
  Const tR _                   -> TupRsingle tR
  PrimConst c                  -> TupRsingle $ SingleScalarType $ primConstType c
  PrimApp f _                  -> snd $ primFunType f
  Index (Var repr _) _         -> arrayRtype repr
  LinearIndex (Var repr _) _   -> arrayRtype repr
  Shape (Var repr _)           -> shapeType $ arrayRshape repr
  ShapeSize{}                  -> TupRsingle scalarTypeInt
  Undef tR                     -> TupRsingle tR
  Coerce _ tR _                -> TupRsingle tR

-- | The type of a primitive constant, rebuilt from the type witness stored in
-- the constructor.
--
primConstType :: PrimConst a -> SingleType a
primConstType c =
  case c of
    PrimMinBound b -> ofBounded b
    PrimMaxBound b -> ofBounded b
    PrimPi       f -> NumSingleType (FloatingNumType f)
  where
    -- Bounded types are witnessed via their integral type.
    ofBounded :: BoundedType b -> SingleType b
    ofBounded (IntegralBoundedType i) = NumSingleType (IntegralNumType i)

-- | The argument and result type of a primitive function, rebuilt from the
-- type witnesses stored in the constructor.
--
primFunType :: PrimFun (a -> b) -> (TypeR a, TypeR b)
primFunType prim =
  case prim of
    -- Num
    PrimAdd n                 -> binop (numR n)
    PrimSub n                 -> binop (numR n)
    PrimMul n                 -> binop (numR n)
    PrimNeg n                 -> endo (numR n)
    PrimAbs n                 -> endo (numR n)
    PrimSig n                 -> endo (numR n)

    -- Integral
    PrimQuot i                -> binop (integralR i)
    PrimRem i                 -> binop (integralR i)
    PrimQuotRem i             -> endo (pairOf (integralR i))
    PrimIDiv i                -> binop (integralR i)
    PrimMod i                 -> binop (integralR i)
    PrimDivMod i              -> endo (pairOf (integralR i))

    -- Bits & FiniteBits
    PrimBAnd i                -> binop (integralR i)
    PrimBOr i                 -> binop (integralR i)
    PrimBXor i                -> binop (integralR i)
    PrimBNot i                -> endo (integralR i)
    PrimBShiftL i             -> shift i
    PrimBShiftR i             -> shift i
    PrimBRotateL i            -> shift i
    PrimBRotateR i            -> shift i
    PrimPopCount i            -> (integralR i, intR)
    PrimCountLeadingZeros i   -> (integralR i, intR)
    PrimCountTrailingZeros i  -> (integralR i, intR)

    -- Fractional, Floating
    PrimFDiv f                -> binop (floatingR f)
    PrimRecip f               -> endo (floatingR f)
    PrimSin f                 -> endo (floatingR f)
    PrimCos f                 -> endo (floatingR f)
    PrimTan f                 -> endo (floatingR f)
    PrimAsin f                -> endo (floatingR f)
    PrimAcos f                -> endo (floatingR f)
    PrimAtan f                -> endo (floatingR f)
    PrimSinh f                -> endo (floatingR f)
    PrimCosh f                -> endo (floatingR f)
    PrimTanh f                -> endo (floatingR f)
    PrimAsinh f               -> endo (floatingR f)
    PrimAcosh f               -> endo (floatingR f)
    PrimAtanh f               -> endo (floatingR f)
    PrimExpFloating f         -> endo (floatingR f)
    PrimSqrt f                -> endo (floatingR f)
    PrimLog f                 -> endo (floatingR f)
    PrimFPow f                -> binop (floatingR f)
    PrimLogBase f             -> binop (floatingR f)

    -- RealFrac
    PrimTruncate f i          -> (floatingR f, integralR i)
    PrimRound f i             -> (floatingR f, integralR i)
    PrimFloor f i             -> (floatingR f, integralR i)
    PrimCeiling f i           -> (floatingR f, integralR i)

    -- RealFloat
    PrimAtan2 f               -> binop (floatingR f)
    PrimIsNaN f               -> (floatingR f, boolR)
    PrimIsInfinite f          -> (floatingR f, boolR)

    -- Relational and equality
    PrimLt s                  -> comparison s
    PrimGt s                  -> comparison s
    PrimLtEq s                -> comparison s
    PrimGtEq s                -> comparison s
    PrimEq s                  -> comparison s
    PrimNEq s                 -> comparison s
    PrimMax s                 -> binop (singleR s)
    PrimMin s                 -> binop (singleR s)

    -- Logical
    PrimLAnd                  -> binop boolR
    PrimLOr                   -> binop boolR
    PrimLNot                  -> endo boolR

    -- general conversion between types
    PrimFromIntegral i n      -> (integralR i, numR n)
    PrimToFloating n f        -> (numR n, floatingR f)
  where
    -- Argument and result have the same type.
    endo :: TypeR t -> (TypeR t, TypeR t)
    endo r = (r, r)

    -- A pair whose components have the same type.
    pairOf :: TypeR t -> TypeR (t, t)
    pairOf r = TupRpair r r

    -- Two arguments of the same type as the result.
    binop :: TypeR t -> (TypeR (t, t), TypeR t)
    binop r = (pairOf r, r)

    -- Shifts and rotates: a value paired with an Int amount.
    shift :: IntegralType t -> (TypeR (t, Int), TypeR t)
    shift i = (TupRpair (integralR i) intR, integralR i)

    -- Two arguments of the same type, yielding a boolean.
    comparison :: SingleType t -> (TypeR (t, t), TypeR PrimBool)
    comparison s = (pairOf (singleR s), boolR)

    singleR :: SingleType t -> TypeR t
    singleR = TupRsingle . SingleScalarType

    numR :: NumType t -> TypeR t
    numR = singleR . NumSingleType

    integralR :: IntegralType t -> TypeR t
    integralR = numR . IntegralNumType

    floatingR :: FloatingType t -> TypeR t
    floatingR = numR . FloatingNumType

    boolR :: TypeR PrimBool
    boolR = TupRsingle scalarTypeWord8

    intR :: TypeR Int
    intR = TupRsingle scalarTypeInt


-- Normal form data
-- ================

-- Array functions are forced by 'rnfOpenAfun'.
instance NFData (OpenAfun aenv f) where
  rnf afun = rnfOpenAfun afun

-- Array computations are forced by 'rnfOpenAcc'.
instance NFData (OpenAcc aenv t) where
  rnf acc = rnfOpenAcc acc

-- Scalar expressions are forced by 'rnfOpenExp'.
instance NFData (OpenExp env aenv t) where
  rnf expr = rnfOpenExp expr

-- Scalar functions are forced by 'rnfOpenFun'.
instance NFData (OpenFun env aenv t) where
  rnf fun = rnfOpenFun fun


-- | A function which forces a term of type @acc@ to normal form, for every
-- array environment and result type.
type NFDataAcc acc = forall aenv t. acc aenv t -> ()

-- | Force an array function to normal form, using 'rnfOpenAcc' for the array
-- terms it contains.
rnfOpenAfun :: OpenAfun aenv t -> ()
rnfOpenAfun afun = rnfPreOpenAfun rnfOpenAcc afun

-- | Force an array function to normal form: every left-hand side of the
-- nested lambdas, and finally the body using the supplied function.
rnfPreOpenAfun :: NFDataAcc acc -> PreOpenAfun acc aenv t -> ()
rnfPreOpenAfun rnfA = \case
  Alam lhs fun -> rnfALeftHandSide lhs `seq` rnfPreOpenAfun rnfA fun
  Abody body   -> rnfA body

-- | Force an array computation to normal form. Nested array terms are forced
-- by this same function.
rnfOpenAcc :: OpenAcc aenv t -> ()
rnfOpenAcc acc =
  case acc of
    OpenAcc pacc -> rnfPreOpenAcc rnfOpenAcc pacc

-- | Force one layer of an array computation to normal form. Nested array
-- terms are forced with the supplied function; scalar expressions and
-- functions with 'rnfOpenExp' and 'rnfOpenFun'.
rnfPreOpenAcc :: forall acc aenv t. HasArraysR acc => NFDataAcc acc -> PreOpenAcc acc aenv t -> ()
rnfPreOpenAcc rnfA = \case
    Alet lhs bnd body ->
      rnfALeftHandSide lhs `seq` rnfA bnd `seq` rnfA body
    Avar var ->
      rnfArrayVar var
    Apair as bs ->
      rnfA as `seq` rnfA bs
    Anil ->
      ()
    Apply repr afun acc ->
      rnfTupR rnfArrayR repr `seq` forceAfun afun `seq` rnfA acc
    Aforeign repr asm afun a ->
      rnfTupR rnfArrayR repr `seq` rnf (strForeign asm) `seq` forceAfun afun `seq` rnfA a
    Acond p a1 a2 ->
      forceExp p `seq` rnfA a1 `seq` rnfA a2
    Awhile p f a ->
      forceAfun p `seq` forceAfun f `seq` rnfA a
    Use repr arr ->
      rnfArray repr arr
    Unit tp x ->
      rnfTypeR tp `seq` forceExp x
    Reshape shr sh a ->
      rnfShapeR shr `seq` forceExp sh `seq` rnfA a
    Generate repr sh f ->
      rnfArrayR repr `seq` forceExp sh `seq` forceFun f
    Transform repr sh p f a ->
      rnfArrayR repr `seq` forceExp sh `seq` forceFun p `seq` forceFun f `seq` rnfA a
    Replicate slice sh a ->
      rnfSliceIndex slice `seq` forceExp sh `seq` rnfA a
    Slice slice a sh ->
      rnfSliceIndex slice `seq` forceExp sh `seq` rnfA a
    Map tp f a ->
      rnfTypeR tp `seq` forceFun f `seq` rnfA a
    ZipWith tp f a1 a2 ->
      rnfTypeR tp `seq` forceFun f `seq` rnfA a1 `seq` rnfA a2
    Fold f z a ->
      forceFun f `seq` rnfMaybe forceExp z `seq` rnfA a
    FoldSeg i f z a s ->
      rnfIntegralType i `seq` forceFun f `seq` rnfMaybe forceExp z `seq` rnfA a `seq` rnfA s
    Scan d f z a ->
      d `seq` forceFun f `seq` rnfMaybe forceExp z `seq` rnfA a
    Scan' d f z a ->
      d `seq` forceFun f `seq` forceExp z `seq` rnfA a
    Permute f d p a ->
      forceFun f `seq` rnfA d `seq` forceFun p `seq` rnfA a
    Backpermute shr sh f a ->
      rnfShapeR shr `seq` forceExp sh `seq` forceFun f `seq` rnfA a
    Stencil sr tp f b a ->
      -- The boundary is traversed at the representation of the stencil's
      -- input: the shape of the argument array with the element type of the
      -- stencil.
      let TupRsingle (ArrayR shr _) = arraysR a
          bndR                      = ArrayR shr (stencilEltR sr)
      in  rnfStencilR sr `seq` rnfTupR rnfScalarType tp `seq` forceFun f `seq` forceBoundary bndR b `seq` rnfA a
    Stencil2 sr1 sr2 tp f b1 a1 b2 a2 ->
      -- Both boundaries use the shape of the first argument array.
      let TupRsingle (ArrayR shr _) = arraysR a1
          bndR1                     = ArrayR shr (stencilEltR sr1)
          bndR2                     = ArrayR shr (stencilEltR sr2)
      in  rnfStencilR sr1 `seq` rnfStencilR sr2 `seq` rnfTupR rnfScalarType tp `seq` forceFun f `seq` forceBoundary bndR1 b1 `seq` forceBoundary bndR2 b2 `seq` rnfA a1 `seq` rnfA a2
  where
    forceAfun :: PreOpenAfun acc aenv' t' -> ()
    forceAfun = rnfPreOpenAfun rnfA

    forceExp :: OpenExp env' aenv' t' -> ()
    forceExp = rnfOpenExp

    forceFun :: OpenFun env' aenv' t' -> ()
    forceFun = rnfOpenFun

    forceBoundary :: ArrayR (Array sh e) -> Boundary aenv' (Array sh e) -> ()
    forceBoundary = rnfBoundary

-- | Force an array variable, including its array representation.
rnfArrayVar :: ArrayVar aenv a -> ()
rnfArrayVar var = rnfVar rnfArrayR var

-- | Force the left-hand side of an array binding, including the array
-- representations it carries.
rnfALeftHandSide :: ALeftHandSide arrs aenv aenv' -> ()
rnfALeftHandSide lhs = rnfLeftHandSide rnfArrayR lhs

-- | Force a stencil boundary condition to normal form.
--
-- The array representation is only used for a 'Constant' boundary, where its
-- element type directs the traversal of the constant value. A 'Function'
-- boundary is forced with 'rnfOpenFun'; the remaining boundaries carry no
-- data.
rnfBoundary :: forall aenv sh e. ArrayR (Array sh e) -> Boundary aenv (Array sh e) -> ()
rnfBoundary _             Clamp        = ()
rnfBoundary _             Mirror       = ()
rnfBoundary _             Wrap         = ()
rnfBoundary (ArrayR _ tR) (Constant c) = rnfConst tR c
rnfBoundary _             (Function f) = rnfOpenFun f

-- | Force the contents of a 'Maybe' with the given function, if there are any.
rnfMaybe :: (a -> ()) -> Maybe a -> ()
rnfMaybe f = maybe () f

-- | Force every element of a list with the given function, front to back.
rnfList :: (a -> ()) -> [a] -> ()
rnfList r = foldr (\x rest -> r x `seq` rest) ()

-- | Force a scalar function to normal form: every left-hand side of the
-- nested lambdas, and finally the body.
rnfOpenFun :: OpenFun env aenv t -> ()
rnfOpenFun = \case
  Lam lhs fun -> rnfELeftHandSide lhs `seq` rnfOpenFun fun
  Body body   -> rnfOpenExp body

-- | Force a scalar expression to normal form, including the type witnesses
-- stored in its nodes and all of its sub-terms.
rnfOpenExp :: forall env aenv t. OpenExp env aenv t -> ()
rnfOpenExp = \case
  Let lhs bnd body ->
    rnfELeftHandSide lhs `seq` rnfOpenExp bnd `seq` rnfOpenExp body
  Evar v ->
    rnfExpVar v
  Foreign tp asm f x ->
    rnfTypeR tp `seq` rnf (strForeign asm) `seq` rnfOpenFun f `seq` rnfOpenExp x
  Const tp c ->
    -- scalars should have (nf == whnf)
    c `seq` rnfScalarType tp
  Undef tp ->
    rnfScalarType tp
  Pair a b ->
    rnfOpenExp a `seq` rnfOpenExp b
  Nil ->
    ()
  VecPack vecr e ->
    rnfVecR vecr `seq` rnfOpenExp e
  VecUnpack vecr e ->
    rnfVecR vecr `seq` rnfOpenExp e
  IndexSlice slice slix sh ->
    rnfSliceIndex slice `seq` rnfOpenExp slix `seq` rnfOpenExp sh
  IndexFull slice slix sl ->
    rnfSliceIndex slice `seq` rnfOpenExp slix `seq` rnfOpenExp sl
  ToIndex shr sh ix ->
    rnfShapeR shr `seq` rnfOpenExp sh `seq` rnfOpenExp ix
  FromIndex shr sh ix ->
    rnfShapeR shr `seq` rnfOpenExp sh `seq` rnfOpenExp ix
  Case e rhs def ->
    rnfOpenExp e `seq` rnfList (\(t,c) -> t `seq` rnfOpenExp c) rhs `seq` rnfMaybe rnfOpenExp def
  Cond p e1 e2 ->
    rnfOpenExp p `seq` rnfOpenExp e1 `seq` rnfOpenExp e2
  While p f x ->
    rnfOpenFun p `seq` rnfOpenFun f `seq` rnfOpenExp x
  PrimConst c ->
    rnfPrimConst c
  PrimApp f x ->
    rnfPrimFun f `seq` rnfOpenExp x
  Index a ix ->
    rnfArrayVar a `seq` rnfOpenExp ix
  LinearIndex a ix ->
    rnfArrayVar a `seq` rnfOpenExp ix
  Shape a ->
    rnfArrayVar a
  ShapeSize shr sh ->
    rnfShapeR shr `seq` rnfOpenExp sh
  Coerce t1 t2 e ->
    rnfScalarType t1 `seq` rnfScalarType t2 `seq` rnfOpenExp e

-- | Force a scalar variable, including its scalar type.
rnfExpVar :: ExpVar env t -> ()
rnfExpVar var = rnfVar rnfScalarType var

-- | Force the left-hand side of a scalar binding, including the scalar types
-- it carries.
rnfELeftHandSide :: ELeftHandSide t env env' -> ()
rnfELeftHandSide lhs = rnfLeftHandSide rnfScalarType lhs

-- | Force a constant value, directed by its type representation. The type
-- representation is inspected first, and then the value.
rnfConst :: TypeR t -> t -> ()
rnfConst tR v =
  case tR of
    TupRunit       -> case v of () -> ()
    -- scalars should have (nf == whnf)
    TupRsingle sR  -> v `seq` rnfScalarType sR
    TupRpair aR bR -> case v of (a, b) -> rnfConst aR a `seq` rnfConst bR b

-- | Force a primitive constant, i.e. the type witness it carries.
rnfPrimConst :: PrimConst c -> ()
rnfPrimConst = \case
  PrimMinBound b -> rnfBoundedType b
  PrimMaxBound b -> rnfBoundedType b
  PrimPi f       -> rnfFloatingType f

-- | Force a primitive function, i.e. the type witnesses it carries.
rnfPrimFun :: PrimFun f -> ()
rnfPrimFun = \case
    -- Num
    PrimAdd n                -> rnfNumType n
    PrimSub n                -> rnfNumType n
    PrimMul n                -> rnfNumType n
    PrimNeg n                -> rnfNumType n
    PrimAbs n                -> rnfNumType n
    PrimSig n                -> rnfNumType n

    -- Integral
    PrimQuot i               -> rnfIntegralType i
    PrimRem i                -> rnfIntegralType i
    PrimQuotRem i            -> rnfIntegralType i
    PrimIDiv i               -> rnfIntegralType i
    PrimMod i                -> rnfIntegralType i
    PrimDivMod i             -> rnfIntegralType i

    -- Bits & FiniteBits
    PrimBAnd i               -> rnfIntegralType i
    PrimBOr i                -> rnfIntegralType i
    PrimBXor i               -> rnfIntegralType i
    PrimBNot i               -> rnfIntegralType i
    PrimBShiftL i            -> rnfIntegralType i
    PrimBShiftR i            -> rnfIntegralType i
    PrimBRotateL i           -> rnfIntegralType i
    PrimBRotateR i           -> rnfIntegralType i
    PrimPopCount i           -> rnfIntegralType i
    PrimCountLeadingZeros i  -> rnfIntegralType i
    PrimCountTrailingZeros i -> rnfIntegralType i

    -- Fractional, Floating
    PrimFDiv f               -> rnfFloatingType f
    PrimRecip f              -> rnfFloatingType f
    PrimSin f                -> rnfFloatingType f
    PrimCos f                -> rnfFloatingType f
    PrimTan f                -> rnfFloatingType f
    PrimAsin f               -> rnfFloatingType f
    PrimAcos f               -> rnfFloatingType f
    PrimAtan f               -> rnfFloatingType f
    PrimSinh f               -> rnfFloatingType f
    PrimCosh f               -> rnfFloatingType f
    PrimTanh f               -> rnfFloatingType f
    PrimAsinh f              -> rnfFloatingType f
    PrimAcosh f              -> rnfFloatingType f
    PrimAtanh f              -> rnfFloatingType f
    PrimExpFloating f        -> rnfFloatingType f
    PrimSqrt f               -> rnfFloatingType f
    PrimLog f                -> rnfFloatingType f
    PrimFPow f               -> rnfFloatingType f
    PrimLogBase f            -> rnfFloatingType f

    -- RealFrac
    PrimTruncate f i         -> rounding f i
    PrimRound f i            -> rounding f i
    PrimFloor f i            -> rounding f i
    PrimCeiling f i          -> rounding f i

    -- RealFloat
    PrimIsNaN f              -> rnfFloatingType f
    PrimIsInfinite f         -> rnfFloatingType f
    PrimAtan2 f              -> rnfFloatingType f

    -- Relational and equality
    PrimLt s                 -> rnfSingleType s
    PrimGt s                 -> rnfSingleType s
    PrimLtEq s               -> rnfSingleType s
    PrimGtEq s               -> rnfSingleType s
    PrimEq s                 -> rnfSingleType s
    PrimNEq s                -> rnfSingleType s
    PrimMax s                -> rnfSingleType s
    PrimMin s                -> rnfSingleType s

    -- Logical
    PrimLAnd                 -> ()
    PrimLOr                  -> ()
    PrimLNot                 -> ()

    -- general conversion between types
    PrimFromIntegral i n     -> rnfIntegralType i `seq` rnfNumType n
    PrimToFloating n f       -> rnfNumType n `seq` rnfFloatingType f
  where
    -- Conversions from a floating to an integral type carry both witnesses.
    rounding :: FloatingType a -> IntegralType b -> ()
    rounding f i = rnfFloatingType f `seq` rnfIntegralType i


-- Template Haskell
-- ================

-- | A function which lifts a term of type @acc@ into a typed Template Haskell
-- expression which rebuilds it, for every array environment and result type.
type LiftAcc acc = forall aenv a. acc aenv a -> Q (TExp (acc aenv a))

-- | Lift an array function into a typed Template Haskell expression which
-- rebuilds it. The body is lifted with the supplied function.
liftPreOpenAfun :: LiftAcc acc -> PreOpenAfun acc aenv t -> Q (TExp (PreOpenAfun acc aenv t))
liftPreOpenAfun liftA = \case
  Abody body   -> [|| Abody $$(liftA body) ||]
  Alam lhs fun -> [|| Alam $$(liftALeftHandSide lhs) $$(liftPreOpenAfun liftA fun) ||]

-- | Lift one layer of an array computation into a typed Template Haskell
-- expression which rebuilds it. Nested array terms are lifted with the
-- supplied function.
liftPreOpenAcc
    :: forall acc aenv a.
       HasArraysR acc
    => LiftAcc acc
    -> PreOpenAcc acc aenv a
    -> Q (TExp (PreOpenAcc acc aenv a))
liftPreOpenAcc liftA = \case
    Alet lhs bnd body ->
      [|| Alet $$(liftALeftHandSide lhs) $$(liftA bnd) $$(liftA body) ||]
    Avar var ->
      [|| Avar $$(liftArrayVar var) ||]
    Apair as bs ->
      [|| Apair $$(liftA as) $$(liftA bs) ||]
    Anil ->
      [|| Anil ||]
    Apply repr f a ->
      [|| Apply $$(liftArraysR repr) $$(afunQ f) $$(liftA a) ||]
    Aforeign repr asm f a ->
      -- 'liftPreOpenAfun' is called directly here, as 'afunQ' is specialised
      -- to the array environment of the term being lifted.
      [|| Aforeign $$(liftArraysR repr) $$(liftForeign asm) $$(liftPreOpenAfun liftA f) $$(liftA a) ||]
    Acond p t e ->
      [|| Acond $$(expQ p) $$(liftA t) $$(liftA e) ||]
    Awhile p f a ->
      [|| Awhile $$(afunQ p) $$(afunQ f) $$(liftA a) ||]
    Use repr a ->
      [|| Use $$(liftArrayR repr) $$(liftArray repr a) ||]
    Unit tp e ->
      [|| Unit $$(liftTypeR tp) $$(expQ e) ||]
    Reshape shr sh a ->
      [|| Reshape $$(liftShapeR shr) $$(expQ sh) $$(liftA a) ||]
    Generate repr sh f ->
      [|| Generate $$(liftArrayR repr) $$(expQ sh) $$(funQ f) ||]
    Transform repr sh p f a ->
      [|| Transform $$(liftArrayR repr) $$(expQ sh) $$(funQ p) $$(funQ f) $$(liftA a) ||]
    Replicate slix sl a ->
      [|| Replicate $$(liftSliceIndex slix) $$(expQ sl) $$(liftA a) ||]
    Slice slix a sh ->
      [|| Slice $$(liftSliceIndex slix) $$(liftA a) $$(expQ sh) ||]
    Map tp f a ->
      [|| Map $$(liftTypeR tp) $$(funQ f) $$(liftA a) ||]
    ZipWith tp f a b ->
      [|| ZipWith $$(liftTypeR tp) $$(funQ f) $$(liftA a) $$(liftA b) ||]
    Fold f z a ->
      [|| Fold $$(funQ f) $$(liftMaybe expQ z) $$(liftA a) ||]
    FoldSeg i f z a s ->
      [|| FoldSeg $$(liftIntegralType i) $$(funQ f) $$(liftMaybe expQ z) $$(liftA a) $$(liftA s) ||]
    Scan d f z a ->
      [|| Scan  $$(liftDirection d) $$(funQ f) $$(liftMaybe expQ z) $$(liftA a) ||]
    Scan' d f z a ->
      [|| Scan' $$(liftDirection d) $$(funQ f) $$(expQ z) $$(liftA a) ||]
    Permute f d p a ->
      [|| Permute $$(funQ f) $$(liftA d) $$(funQ p) $$(liftA a) ||]
    Backpermute shr sh p a ->
      [|| Backpermute $$(liftShapeR shr) $$(expQ sh) $$(funQ p) $$(liftA a) ||]
    Stencil sr tp f b a ->
      -- The boundary is lifted at the representation of the stencil's input:
      -- the shape of the argument array with the element type of the stencil.
      let TupRsingle (ArrayR shr _) = arraysR a
          bndR                      = ArrayR shr (stencilEltR sr)
      in  [|| Stencil $$(liftStencilR sr) $$(liftTypeR tp) $$(funQ f) $$(boundaryQ bndR b) $$(liftA a) ||]
    Stencil2 sr1 sr2 tp f b1 a1 b2 a2 ->
      -- Both boundaries use the shape of the first argument array.
      let TupRsingle (ArrayR shr _) = arraysR a1
          bndR1                     = ArrayR shr (stencilEltR sr1)
          bndR2                     = ArrayR shr (stencilEltR sr2)
      in  [|| Stencil2 $$(liftStencilR sr1) $$(liftStencilR sr2) $$(liftTypeR tp) $$(funQ f) $$(boundaryQ bndR1 b1) $$(liftA a1) $$(boundaryQ bndR2 b2) $$(liftA a2) ||]
  where
    expQ :: OpenExp env aenv t -> Q (TExp (OpenExp env aenv t))
    expQ = liftOpenExp

    funQ :: OpenFun env aenv t -> Q (TExp (OpenFun env aenv t))
    funQ = liftOpenFun

    afunQ :: PreOpenAfun acc aenv f -> Q (TExp (PreOpenAfun acc aenv f))
    afunQ = liftPreOpenAfun liftA

    boundaryQ :: ArrayR (Array sh e) -> Boundary aenv (Array sh e) -> Q (TExp (Boundary aenv (Array sh e)))
    boundaryQ = liftBoundary


-- | Lift the left-hand side of an array binding, including the array
-- representations it carries.
liftALeftHandSide :: ALeftHandSide arrs aenv aenv' -> Q (TExp (ALeftHandSide arrs aenv aenv'))
liftALeftHandSide lhs = liftLeftHandSide liftArrayR lhs

-- | Lift an array variable, including its array representation.
liftArrayVar :: ArrayVar aenv a -> Q (TExp (ArrayVar aenv a))
liftArrayVar var = liftVar liftArrayR var

-- | Lift a scan direction into the typed expression of the same constructor.
liftDirection :: Direction -> Q (TExp Direction)
liftDirection = \case
  LeftToRight -> [|| LeftToRight ||]
  RightToLeft -> [|| RightToLeft ||]

-- | Lift a 'Maybe', using the given function for the contents, if any.
liftMaybe :: (a -> Q (TExp a)) -> Maybe a -> Q (TExp (Maybe a))
liftMaybe f = maybe [|| Nothing ||] (\x -> [|| Just $$(f x) ||])

-- | Lift a list, using the given function for each element. The result is
-- the corresponding chain of cons cells ending in the empty list.
liftList :: (a -> Q (TExp a)) -> [a] -> Q (TExp [a])
liftList f = foldr (\x rest -> [|| $$(f x) : $$rest ||]) [|| [] ||]

-- | Lift a scalar function into a typed Template Haskell expression which
-- rebuilds it.
liftOpenFun
    :: OpenFun env aenv t
    -> Q (TExp (OpenFun env aenv t))
liftOpenFun = \case
  Body body   -> [|| Body $$(liftOpenExp body) ||]
  Lam lhs fun -> [|| Lam $$(liftELeftHandSide lhs) $$(liftOpenFun fun) ||]

-- | Lift a scalar expression into a typed Template Haskell expression which
-- rebuilds it.
liftOpenExp
    :: forall env aenv t.
       OpenExp env aenv t
    -> Q (TExp (OpenExp env aenv t))
liftOpenExp = \case
    Let lhs bnd body ->
      -- 'liftOpenExp' is called directly here, as 'expQ' is specialised to
      -- the environments of the term being lifted.
      [|| Let $$(liftELeftHandSide lhs) $$(liftOpenExp bnd) $$(liftOpenExp body) ||]
    Evar var ->
      [|| Evar $$(liftExpVar var) ||]
    Foreign repr asm f x ->
      -- Likewise, 'liftOpenFun' is called directly rather than via 'funQ'.
      [|| Foreign $$(liftTypeR repr) $$(liftForeign asm) $$(liftOpenFun f) $$(expQ x) ||]
    Const tp c ->
      [|| Const $$(liftScalarType tp) $$(liftElt (TupRsingle tp) c) ||]
    Undef tp ->
      [|| Undef $$(liftScalarType tp) ||]
    Pair a b ->
      [|| Pair $$(expQ a) $$(expQ b) ||]
    Nil ->
      [|| Nil ||]
    VecPack vecr e ->
      [|| VecPack   $$(liftVecR vecr) $$(expQ e) ||]
    VecUnpack vecr e ->
      [|| VecUnpack $$(liftVecR vecr) $$(expQ e) ||]
    IndexSlice slice slix sh ->
      [|| IndexSlice $$(liftSliceIndex slice) $$(expQ slix) $$(expQ sh) ||]
    IndexFull slice slix sl ->
      [|| IndexFull $$(liftSliceIndex slice) $$(expQ slix) $$(expQ sl) ||]
    ToIndex shr sh ix ->
      [|| ToIndex $$(liftShapeR shr) $$(expQ sh) $$(expQ ix) ||]
    FromIndex shr sh ix ->
      [|| FromIndex $$(liftShapeR shr) $$(expQ sh) $$(expQ ix) ||]
    Case p rhs def ->
      [|| Case $$(expQ p) $$(liftList (\(t,c) -> [|| (t, $$(expQ c)) ||]) rhs) $$(liftMaybe expQ def) ||]
    Cond p t e ->
      [|| Cond $$(expQ p) $$(expQ t) $$(expQ e) ||]
    While p f x ->
      [|| While $$(funQ p) $$(funQ f) $$(expQ x) ||]
    PrimConst t ->
      [|| PrimConst $$(liftPrimConst t) ||]
    PrimApp f x ->
      [|| PrimApp $$(liftPrimFun f) $$(expQ x) ||]
    Index a ix ->
      [|| Index $$(liftArrayVar a) $$(expQ ix) ||]
    LinearIndex a ix ->
      [|| LinearIndex $$(liftArrayVar a) $$(expQ ix) ||]
    Shape a ->
      [|| Shape $$(liftArrayVar a) ||]
    ShapeSize shr ix ->
      [|| ShapeSize $$(liftShapeR shr) $$(expQ ix) ||]
    Coerce t1 t2 e ->
      [|| Coerce $$(liftScalarType t1) $$(liftScalarType t2) $$(expQ e) ||]
  where
    -- Lifting of sub-terms in the same environments as the term itself.
    expQ :: OpenExp env aenv e -> Q (TExp (OpenExp env aenv e))
    expQ = liftOpenExp

    funQ :: OpenFun env aenv f -> Q (TExp (OpenFun env aenv f))
    funQ = liftOpenFun

-- | Lift the left-hand side of a scalar binding, including the scalar types
-- it carries.
liftELeftHandSide :: ELeftHandSide t env env' -> Q (TExp (ELeftHandSide t env env'))
liftELeftHandSide lhs = liftLeftHandSide liftScalarType lhs

-- | Lift a scalar variable, including its scalar type.
liftExpVar :: ExpVar env t -> Q (TExp (ExpVar env t))
liftExpVar var = liftVar liftScalarType var

-- | Lift a stencil boundary condition into a typed Template Haskell
-- expression which rebuilds it.
--
-- The array representation is only used for a 'Constant' boundary, where its
-- element type directs the lifting of the constant value. A 'Function'
-- boundary is lifted with 'liftOpenFun'; the remaining boundaries carry no
-- data.
liftBoundary
    :: forall aenv sh e.
       ArrayR (Array sh e)
    -> Boundary aenv (Array sh e)
    -> Q (TExp (Boundary aenv (Array sh e)))
liftBoundary _             Clamp        = [|| Clamp ||]
liftBoundary _             Mirror       = [|| Mirror ||]
liftBoundary _             Wrap         = [|| Wrap ||]
liftBoundary (ArrayR _ tp) (Constant v) = [|| Constant $$(liftElt tp v) ||]
liftBoundary _             (Function f) = [|| Function $$(liftOpenFun f) ||]

-- | Lift a primitive constant, together with the type witness it carries.
liftPrimConst :: PrimConst c -> Q (TExp (PrimConst c))
liftPrimConst = \case
  PrimMinBound b -> [|| PrimMinBound $$(liftBoundedType b) ||]
  PrimMaxBound b -> [|| PrimMaxBound $$(liftBoundedType b) ||]
  PrimPi f       -> [|| PrimPi $$(liftFloatingType f) ||]

-- | Lift a primitive function, together with the type witnesses it carries.
liftPrimFun :: PrimFun f -> Q (TExp (PrimFun f))
liftPrimFun = \case
  -- Num
  PrimAdd n                -> [|| PrimAdd $$(liftNumType n) ||]
  PrimSub n                -> [|| PrimSub $$(liftNumType n) ||]
  PrimMul n                -> [|| PrimMul $$(liftNumType n) ||]
  PrimNeg n                -> [|| PrimNeg $$(liftNumType n) ||]
  PrimAbs n                -> [|| PrimAbs $$(liftNumType n) ||]
  PrimSig n                -> [|| PrimSig $$(liftNumType n) ||]

  -- Integral
  PrimQuot i               -> [|| PrimQuot $$(liftIntegralType i) ||]
  PrimRem i                -> [|| PrimRem $$(liftIntegralType i) ||]
  PrimQuotRem i            -> [|| PrimQuotRem $$(liftIntegralType i) ||]
  PrimIDiv i               -> [|| PrimIDiv $$(liftIntegralType i) ||]
  PrimMod i                -> [|| PrimMod $$(liftIntegralType i) ||]
  PrimDivMod i             -> [|| PrimDivMod $$(liftIntegralType i) ||]

  -- Bits & FiniteBits
  PrimBAnd i               -> [|| PrimBAnd $$(liftIntegralType i) ||]
  PrimBOr i                -> [|| PrimBOr $$(liftIntegralType i) ||]
  PrimBXor i               -> [|| PrimBXor $$(liftIntegralType i) ||]
  PrimBNot i               -> [|| PrimBNot $$(liftIntegralType i) ||]
  PrimBShiftL i            -> [|| PrimBShiftL $$(liftIntegralType i) ||]
  PrimBShiftR i            -> [|| PrimBShiftR $$(liftIntegralType i) ||]
  PrimBRotateL i           -> [|| PrimBRotateL $$(liftIntegralType i) ||]
  PrimBRotateR i           -> [|| PrimBRotateR $$(liftIntegralType i) ||]
  PrimPopCount i           -> [|| PrimPopCount $$(liftIntegralType i) ||]
  PrimCountLeadingZeros i  -> [|| PrimCountLeadingZeros $$(liftIntegralType i) ||]
  PrimCountTrailingZeros i -> [|| PrimCountTrailingZeros $$(liftIntegralType i) ||]

  -- Fractional, Floating
  PrimFDiv f               -> [|| PrimFDiv $$(liftFloatingType f) ||]
  PrimRecip f              -> [|| PrimRecip $$(liftFloatingType f) ||]
  PrimSin f                -> [|| PrimSin $$(liftFloatingType f) ||]
  PrimCos f                -> [|| PrimCos $$(liftFloatingType f) ||]
  PrimTan f                -> [|| PrimTan $$(liftFloatingType f) ||]
  PrimAsin f               -> [|| PrimAsin $$(liftFloatingType f) ||]
  PrimAcos f               -> [|| PrimAcos $$(liftFloatingType f) ||]
  PrimAtan f               -> [|| PrimAtan $$(liftFloatingType f) ||]
  PrimSinh f               -> [|| PrimSinh $$(liftFloatingType f) ||]
  PrimCosh f               -> [|| PrimCosh $$(liftFloatingType f) ||]
  PrimTanh f               -> [|| PrimTanh $$(liftFloatingType f) ||]
  PrimAsinh f              -> [|| PrimAsinh $$(liftFloatingType f) ||]
  PrimAcosh f              -> [|| PrimAcosh $$(liftFloatingType f) ||]
  PrimAtanh f              -> [|| PrimAtanh $$(liftFloatingType f) ||]
  PrimExpFloating f        -> [|| PrimExpFloating $$(liftFloatingType f) ||]
  PrimSqrt f               -> [|| PrimSqrt $$(liftFloatingType f) ||]
  PrimLog f                -> [|| PrimLog $$(liftFloatingType f) ||]
  PrimFPow f               -> [|| PrimFPow $$(liftFloatingType f) ||]
  PrimLogBase f            -> [|| PrimLogBase $$(liftFloatingType f) ||]

  -- RealFrac
  PrimTruncate f i         -> [|| PrimTruncate $$(liftFloatingType f) $$(liftIntegralType i) ||]
  PrimRound f i            -> [|| PrimRound $$(liftFloatingType f) $$(liftIntegralType i) ||]
  PrimFloor f i            -> [|| PrimFloor $$(liftFloatingType f) $$(liftIntegralType i) ||]
  PrimCeiling f i          -> [|| PrimCeiling $$(liftFloatingType f) $$(liftIntegralType i) ||]

  -- RealFloat
  PrimIsNaN f              -> [|| PrimIsNaN $$(liftFloatingType f) ||]
  PrimIsInfinite f         -> [|| PrimIsInfinite $$(liftFloatingType f) ||]
  PrimAtan2 f              -> [|| PrimAtan2 $$(liftFloatingType f) ||]

  -- Relational and equality
  PrimLt s                 -> [|| PrimLt $$(liftSingleType s) ||]
  PrimGt s                 -> [|| PrimGt $$(liftSingleType s) ||]
  PrimLtEq s               -> [|| PrimLtEq $$(liftSingleType s) ||]
  PrimGtEq s               -> [|| PrimGtEq $$(liftSingleType s) ||]
  PrimEq s                 -> [|| PrimEq $$(liftSingleType s) ||]
  PrimNEq s                -> [|| PrimNEq $$(liftSingleType s) ||]
  PrimMax s                -> [|| PrimMax $$(liftSingleType s) ||]
  PrimMin s                -> [|| PrimMin $$(liftSingleType s) ||]

  -- Logical
  PrimLAnd                 -> [|| PrimLAnd ||]
  PrimLOr                  -> [|| PrimLOr ||]
  PrimLNot                 -> [|| PrimLNot ||]

  -- general conversion between types
  PrimFromIntegral i n     -> [|| PrimFromIntegral $$(liftIntegralType i) $$(liftNumType n) ||]
  PrimToFloating n f       -> [|| PrimToFloating $$(liftNumType n) $$(liftFloatingType f) ||]


-- | Render a short, human-readable label for the outermost constructor of an
-- array computation. Sub-terms are not traversed; only the head operation is
-- named.
--
-- Most constructors map to their own name. The exceptions are:
--
--  * 'Avar' appends the integer form of the variable's index (via 'idxToInt').
--
--  * 'Use' appends a rendering of the array produced by 'showArrayShort'.
--
--  * 'Fold', 'FoldSeg' and 'Scan' get a @1@ suffix when their optional
--    argument is 'Nothing'; scans additionally carry the direction letter
--    given by 'showsDirection'.
--
showPreAccOp :: forall acc aenv arrs. PreOpenAcc acc aenv arrs -> String
showPreAccOp = \case
    Alet{}            -> "Alet"
    Avar (Var _ ix)   -> "Avar a" ++ show (idxToInt ix)
    Use aR a          -> "Use " ++ showArrayShort 5 (showsElt (arrayRtype aR)) aR a
    Apply{}           -> "Apply"
    Aforeign{}        -> "Aforeign"
    Acond{}           -> "Acond"
    Awhile{}          -> "Awhile"
    Apair{}           -> "Apair"
    Anil              -> "Anil"
    Unit{}            -> "Unit"
    Generate{}        -> "Generate"
    Transform{}       -> "Transform"
    Reshape{}         -> "Reshape"
    Replicate{}       -> "Replicate"
    Slice{}           -> "Slice"
    Map{}             -> "Map"
    ZipWith{}         -> "ZipWith"
    Fold _ z _        -> "Fold" ++ oneIfAbsent z
    FoldSeg _ _ z _ _ -> "Fold" ++ oneIfAbsent z ++ "Seg"
    Scan d _ z _      -> "Scan" ++ showsDirection d (oneIfAbsent z)
    Scan' d _ _ _     -> "Scan" ++ showsDirection d "'"
    Permute{}         -> "Permute"
    Backpermute{}     -> "Backpermute"
    Stencil{}         -> "Stencil"
    Stencil2{}        -> "Stencil2"
  where
    -- Suffix marking the variant whose optional argument is missing: "1"
    -- for 'Nothing', the empty string otherwise.
    oneIfAbsent :: Maybe a -> String
    oneIfAbsent = maybe "1" (const "")

-- | Prepend the single-letter tag for a traversal direction: @l@ for
-- 'LeftToRight' and @r@ for 'RightToLeft'.
--
showsDirection :: Direction -> ShowS
showsDirection = \case
  LeftToRight -> (:) 'l'
  RightToLeft -> (:) 'r'

-- | Render a short, human-readable label for the outermost constructor of a
-- scalar expression. Sub-expressions are not traversed.
--
-- Most constructors map to their own name. The exceptions are 'Evar', which
-- is shown as @Var x@ followed by the integer form of its index (via
-- 'idxToInt'), and 'Const', which is followed by the constant as rendered by
-- 'showElt'.
--
showExpOp :: forall aenv env t. OpenExp aenv env t -> String
showExpOp = \case
  Let{}           -> "Let"
  Evar (Var _ ix) -> "Var x" ++ show (idxToInt ix)
  Const tp c      -> "Const " ++ showElt (TupRsingle tp) c
  Undef{}         -> "Undef"
  Foreign{}       -> "Foreign"
  Pair{}          -> "Pair"
  Nil{}           -> "Nil"
  VecPack{}       -> "VecPack"
  VecUnpack{}     -> "VecUnpack"
  IndexSlice{}    -> "IndexSlice"
  IndexFull{}     -> "IndexFull"
  ToIndex{}       -> "ToIndex"
  FromIndex{}     -> "FromIndex"
  Case{}          -> "Case"
  Cond{}          -> "Cond"
  While{}         -> "While"
  PrimConst{}     -> "PrimConst"
  PrimApp{}       -> "PrimApp"
  Index{}         -> "Index"
  LinearIndex{}   -> "LinearIndex"
  Shape{}         -> "Shape"
  ShapeSize{}     -> "ShapeSize"
  Coerce{}        -> "Coerce"