{-# LANGUAGE GADTs               #-}
{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE QuasiQuotes         #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# OPTIONS -fno-warn-name-shadowing #-}
-- |
-- Module      : Data.Array.Accelerate.CUDA.CodeGen
-- Copyright   : [2008..2014] Manuel M T Chakravarty, Gabriele Keller
--               [2009..2014] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <tmcdonell@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.CUDA.CodeGen (

  CUTranslSkel, codegenAcc,

) where

-- libraries
import Prelude                                                  hiding ( id, exp, replicate )
import Control.Applicative                                      ( (<$>), (<*>) )
import Control.Monad.State.Strict
import Data.Loc
import Data.Char
import Data.HashSet                                             ( HashSet )
import Foreign.CUDA.Analysis
import Language.C.Quote.CUDA
import qualified Language.C                                     as C
import qualified Data.HashSet                                   as Set

-- friends
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Tuple
import Data.Array.Accelerate.Trafo
import Data.Array.Accelerate.Pretty                             ()
import Data.Array.Accelerate.Analysis.Shape
import Data.Array.Accelerate.Array.Sugar                        ( Array, Shape, Elt, EltRepr )
import Data.Array.Accelerate.Array.Representation               ( SliceIndex(..) )
import qualified Data.Array.Accelerate.Array.Sugar              as Sugar
import qualified Data.Array.Accelerate.Analysis.Type            as Sugar

import Data.Array.Accelerate.CUDA.AST                           hiding ( Val(..), prj )
import Data.Array.Accelerate.CUDA.CodeGen.Base
import Data.Array.Accelerate.CUDA.CodeGen.Type
import Data.Array.Accelerate.CUDA.CodeGen.Monad
import Data.Array.Accelerate.CUDA.CodeGen.Mapping
import Data.Array.Accelerate.CUDA.CodeGen.IndexSpace
import Data.Array.Accelerate.CUDA.CodeGen.PrefixSum
import Data.Array.Accelerate.CUDA.CodeGen.Reduction
import Data.Array.Accelerate.CUDA.CodeGen.Stencil
import Data.Array.Accelerate.CUDA.Foreign.Import                ( canExecuteExp )


-- Local environments
--
-- A valuation for the scalar environment 'env', indexed by de Bruijn indices
-- (see 'prj'). Each entry is a list of C expressions rather than a single
-- expression.
--
data Val env where
  -- The empty environment: no scalar variables are bound.
  Empty ::                       Val ()
  -- Extend an environment with the expressions standing for the innermost
  -- binding. The bound type 's' is phantom; the payload does not depend on it.
  Push  :: Val env -> [C.Exp] -> Val (env, s)

-- Look up the C expressions bound at a de Bruijn index by walking the index
-- and the valuation in lock-step.
--
prj :: Idx env t -> Val env -> [C.Exp]
prj idx val =
  case (idx, val) of
    (ZeroIdx,      Push _    here) -> here
    (SuccIdx idx', Push rest _   ) -> prj idx' rest
    _                              -> $internalError "prj" "inconsistent valuation"


-- Array expressions
-- -----------------

-- | Instantiate an array computation with a set of concrete function and type
-- definitions to fix the parameters of an algorithmic skeleton. The generated
-- code can then be pretty-printed to file, and compiled to object code
-- executable on the device. This generates a set of __global__ device functions
-- required to compute the given computation node.
--
-- The code generator requires that the only array form allowed within scalar
-- expressions are array variables. The list of array-valued scalar inputs are
-- taken as the environment.
--
-- TODO: include a measure of how much shared memory a kernel requires.
--
codegenAcc :: forall aenv arrs. DeviceProperties -> DelayedOpenAcc aenv arrs -> Gamma aenv -> [ CUTranslSkel aenv arrs ]
-- Only manifest nodes are compiled; a delayed node reaching this point is
-- reported as an internal error.
codegenAcc _   Delayed{}       _    = $internalError "codegenAcc" "expected manifest array"
-- Each supported primitive is dispatched to its skeleton generator (mkMap,
-- mkFold, ...). The scalar functions, expressions and delayed input arrays are
-- translated first (travF1/travF2/travE/travD) inside the CUDA monad, and the
-- whole computation is then run by 'codegen'.
codegenAcc dev (Manifest pacc) aenv
  = codegen
  $ case pacc of

      -- Producers
      Map f a                   -> mkMap dev aenv       <$> travF1 f <*> travD a
      Generate _ f              -> mkGenerate dev aenv  <$> travF1 f
      Transform _ p f a         -> mkTransform dev aenv <$> travF1 p <*> travF1 f  <*> travD a
      -- Backpermute reuses the transform skeleton, with the local identity
      -- function 'id' in the position of the element function.
      Backpermute _ p a         -> mkTransform dev aenv <$> travF1 p <*> travF1 id <*> travD a

      -- Consumers
      Fold f z a                -> mkFold  dev aenv     <$> travF2 f <*> travE z  <*> travD a
      Fold1 f a                 -> mkFold1 dev aenv     <$> travF2 f <*> travD a
      FoldSeg f z a s           -> mkFoldSeg dev aenv   <$> travF2 f <*> travE z  <*> travD a <*> travD s
      Fold1Seg f a s            -> mkFold1Seg dev aenv  <$> travF2 f <*> travD a  <*> travD s
      Scanl f z a               -> mkScanl dev aenv     <$> travF2 f <*> travE z  <*> travD a
      Scanr f z a               -> mkScanr dev aenv     <$> travF2 f <*> travE z  <*> travD a
      Scanl' f z a              -> mkScanl' dev aenv    <$> travF2 f <*> travE z  <*> travD a
      Scanr' f z a              -> mkScanr' dev aenv    <$> travF2 f <*> travE z  <*> travD a
      Scanl1 f a                -> mkScanl1 dev aenv    <$> travF2 f <*> travD a
      Scanr1 f a                -> mkScanr1 dev aenv    <$> travF2 f <*> travD a
      -- The second field of Permute is not consulted when generating code.
      Permute f _ p a           -> mkPermute dev aenv   <$> travF2 f <*> travF1 p <*> travD a
      -- Note the argument order of travB: the array first, then its boundary.
      Stencil f b a             -> mkStencil dev aenv   <$> travF1 f <*> travB a b
      Stencil2 f b1 a1 b2 a2    -> mkStencil2 dev aenv  <$> travF2 f <*> travB a1 b1 <*> travB a2 b2

      -- Non-computation forms -> sadness
      Alet{}                    -> unexpectedError
      Avar{}                    -> unexpectedError
      Apply{}                   -> unexpectedError
      Acond{}                   -> unexpectedError
      Awhile{}                  -> unexpectedError
      Atuple{}                  -> unexpectedError
      Aprj{}                    -> unexpectedError
      Use{}                     -> unexpectedError
      Unit{}                    -> unexpectedError
      Aforeign{}                -> unexpectedError
      Reshape{}                 -> unexpectedError

      -- Forms reported as "fusible material": they are rejected here with a
      -- distinct error message.
      Replicate{}               -> fusionError
      Slice{}                   -> fusionError
      ZipWith{}                 -> fusionError

  where
    -- Run the code generation monad, then prepend one '#include "<header>"'
    -- declaration to the code of every skeleton for each entry in the
    -- 'headers' set of the final state.
    codegen :: CUDA [CUTranslSkel aenv a] -> [CUTranslSkel aenv a]
    codegen cuda =
      let (skeletons, st)                = runCUDA cuda
          addTo (CUTranslSkel name code) =
            CUTranslSkel name (Set.foldr (\h c -> [cedecl| $esc:("#include \"" ++ h ++ "\"") |] : c) code (headers st))
      in
      map addTo skeletons

    -- The identity scalar function (this is why Prelude's 'id' is hidden).
    id :: Elt a => DelayedFun aenv (a -> a)
    id = Lam (Body (Var ZeroIdx))

    -- code generation for delayed arrays
    --
    -- Inputs to a skeleton must be in delayed form; their extent expression
    -- and both indexing functions are translated.
    travD :: (Shape sh, Elt e) => DelayedOpenAcc aenv (Array sh e) -> CUDA (CUDelayedAcc aenv sh e)
    travD Manifest{}  = $internalError "codegenAcc" "expected delayed array"
    travD Delayed{..} = CUDelayed <$> travE extentD
                                  <*> travF1 indexD
                                  <*> travF1 linearIndexD

    -- scalar code generation
    travF1 :: DelayedFun aenv (a -> b) -> CUDA (CUFun1 aenv (a -> b))
    travF1 = codegenFun1 dev aenv

    travF2 :: DelayedFun aenv (a -> b -> c) -> CUDA (CUFun2 aenv (a -> b -> c))
    travF2 = codegenFun2 dev aenv

    travE :: DelayedExp aenv t -> CUDA (CUExp aenv t)
    travE = codegenExp dev aenv

    -- Translate a stencil boundary condition. The array argument is ignored;
    -- only a constant boundary carries a value, which is converted with
    -- 'codegenConst' and paired with an empty list in the first component.
    travB :: forall sh e. Elt e
          => DelayedOpenAcc aenv (Array sh e) -> Boundary (EltRepr e) -> CUDA (Boundary (CUExp aenv e))
    travB _ Clamp        = return Clamp
    travB _ Mirror       = return Mirror
    travB _ Wrap         = return Wrap
    travB _ (Constant c) = return . Constant $ CUExp ([], codegenConst (Sugar.eltType (undefined::e)) c)

    -- caffeine and misery
    --
    -- Error values for array forms that are rejected above; both messages
    -- include the rendering of the offending primitive.
    prim :: String
    prim                = showPreAccOp pacc
    unexpectedError     = $internalError "codegenAcc" $ "unexpected array primitive: " ++ prim
    fusionError         = $internalError "codegenAcc" $ "unexpected fusible material: " ++ prim


-- Scalar function abstraction
-- ---------------------------

-- Generate code for scalar function abstractions.
--
-- This is quite awkward: we have an outer monad to generate fresh variable
-- names, but since we know that even if the function is applied many times (for
-- example, collective operations such as 'fold' and 'scan'), the variables will
-- not shadow each other. Thus, we don't need fresh names at _every_ invocation
-- site, so we hack this a bit to return a pure closure.
--
-- Note that the implementation of def-use analysis used for dead code
-- elimination requires that we always generate code for closed functions.
-- Additionally, we require two passes over the function: once when performing
-- the analysis, and a second time when instantiating the function in the
-- skeleton.
--
codegenFun1
    :: forall aenv a b. DeviceProperties
    -> Gamma aenv
    -> DelayedFun aenv (a -> b)
    -> CUDA (CUFun1 aenv (a -> b))
codegenFun1 dev aenv fun =
  case fun of
    Lam (Body body) -> do
      let
          -- Generate the function body with the given terms standing in for
          -- its argument, returning the local bindings that were produced
          -- together with the (recorded) result expressions.
          instantiate :: Rvalue x => [x] -> Gen ([C.BlockItem], [C.Exp])
          instantiate args = do
            outs     <- codegenOpenExp dev aenv body (Empty `Push` map rvalue args)
            results  <- mapM use outs
            bindings <- getEnv
            return (bindings, results)

          -- Placeholder argument names for the analysis pass; the real names
          -- are supplied when the skeleton instantiates the closure.
          (_, dummy, _) = locals "undefined_x" (undefined :: a)

      -- First pass: run once on the placeholders to learn which of them are
      -- used. Second pass: the returned closure re-runs generation from the
      -- state captured here.
      seed         <- get
      ExpST _ used <- execCGM (instantiate dummy)
      return $ CUFun1 (mark used dummy)
                      (\xs -> evalState (evalCGM (instantiate xs)) seed)

    _ -> $internalError "codegenFun1" "expected unary function"


codegenFun2
    :: forall aenv a b c. DeviceProperties
    -> Gamma aenv
    -> DelayedFun aenv (a -> b -> c)
    -> CUDA (CUFun2 aenv (a -> b -> c))
codegenFun2 dev aenv fun =
  case fun of
    Lam (Lam (Body body)) -> do
      let
          -- Generate the function body with the given terms standing in for
          -- its two arguments. The first argument is pushed first, so the
          -- second one sits at index zero of the environment.
          instantiate :: (Rvalue x, Rvalue y) => [x] -> [y] -> Gen ([C.BlockItem], [C.Exp])
          instantiate xs ys = do
            outs     <- codegenOpenExp dev aenv body (Empty `Push` map rvalue xs `Push` map rvalue ys)
            results  <- mapM use outs
            bindings <- getEnv
            return (bindings, results)

          -- Placeholder argument names used only by the analysis pass.
          (_, dummyX, _) = locals "undefined_x" (undefined :: a)
          (_, dummyY, _) = locals "undefined_y" (undefined :: b)

      -- Analyse once on the placeholders, then hand back a closure that
      -- re-runs generation from the state captured here.
      seed         <- get
      ExpST _ used <- execCGM (instantiate dummyX dummyY)
      return $ CUFun2 (mark used dummyX) (mark used dummyY)
                      (\xs ys -> evalState (evalCGM (instantiate xs ys)) seed)

    _ -> $internalError "codegenFun2" "expected binary function"


-- It is important to filter output terms of a function that will not be used.
-- Consider this pattern from the map kernel:
--
--   items:(x      .=. get ix)
--   items:(set ix .=. f x)
--
-- If this is applied to the following expression where we extract the first
-- component of a 4-tuple:
--
--   map (\t -> let (x,_,_,_) = unlift t in x) vec4
--
-- Then the first line 'get ix' still reads all four components of the input
-- vector, even though only one is used. Conversely, if we directly apply the
-- data fetch to f, then the redundant reads are eliminated, but this is simply
-- inlining the read into the function body, so if the argument is used multiple
-- times so too is the data read multiple times.
--
-- The procedure for determining which variables are used is to record each
-- singleton expression produced throughout code generation to a set. It doesn't
-- matter if the expression is a variable (which we are interested in) or
-- something else. Once generation completes, we can test which of the input
-- variables also appear in the output set. Later, we integrate this information
-- when assigning to l-values: if the variable is not in the set, simply elide
-- that statement.
--
-- In the above map example, this means that the usage data is taken from 'f',
-- but applies to which results of 'get ix' are committed to memory.
--
mark :: HashSet C.Exp -> [C.Exp] -> ([a] -> [(Bool,a)])
mark used xs =
  -- Pair each element of a later list, positionally, with whether the
  -- corresponding expression of 'xs' is a member of the 'used' set.
  zip [ x `Set.member` used | x <- xs ]

-- Record a singleton result with 'use'; lists of any other length pass
-- through untouched. The input list is returned unchanged in both cases.
--
visit :: [C.Exp] -> Gen [C.Exp]
visit es =
  case es of
    [e] -> use e >> return es
    _   -> return es


-- Scalar expressions
-- ------------------

-- Generation of scalar expressions
--
codegenExp :: DeviceProperties -> Gamma aenv -> DelayedExp aenv t -> CUDA (CUExp aenv t)
codegenExp dev aenv exp = evalCGM $
  -- A closed expression: start from the empty scalar environment, then pair
  -- the resulting terms with whatever local bindings were accumulated.
  codegenOpenExp dev aenv exp Empty >>= \code ->
  getEnv                            >>= \env  ->
  return $! CUExp (env, code)


-- The core of the code generator, building lists of untyped C expression
-- fragments. This is tricky to get right!
--
codegenOpenExp
    :: forall aenv env' t'. DeviceProperties
    -> Gamma aenv
    -> DelayedOpenExp env' aenv t'
    -> Val env'
    -> Gen [C.Exp]
-- The device properties and array environment are fixed for the whole
-- traversal; all the local helpers below close over them.
codegenOpenExp dev aenv = cvtE
  where
    -- Generate code for a scalar expression in depth-first order. We run under
    -- a monad that generates fresh names and keeps track of let bindings.
    --
    -- The result of every node is passed through 'visit', which records
    -- singleton results with 'use'.
    --
    cvtE :: forall env t. DelayedOpenExp env aenv t -> Val env -> Gen [C.Exp]
    cvtE exp env = visit =<<
      case exp of
        Let bnd body            -> elet bnd body env
        Var ix                  -> return $ prj ix env
        PrimConst c             -> return $ [codegenPrimConst c]
        Const c                 -> return $ codegenConst (Sugar.eltType (undefined::t)) c
        -- 'primApp' yields a single expression; the 'return' here is the list
        -- one, wrapping it into a singleton.
        PrimApp f x             -> return <$> primApp f x env
        Tuple t                 -> cvtT t env
        Prj i t                 -> prjT i t exp env
        Cond p t e              -> cond p t e env
        While p f x             -> while p f x env

        -- Shapes and indices
        --
        -- IndexNil and IndexAny contribute no C terms; IndexCons appends the
        -- terms of the new dimension after those of the rest of the shape.
        IndexNil                -> return []
        IndexAny                -> return []
        IndexCons sh sz         -> (++) <$> cvtE sh env <*> cvtE sz env
        IndexHead ix            -> return . cindexHead <$> cvtE ix env
        IndexTail ix            ->          cindexTail <$> cvtE ix env
        IndexSlice ix slix sh   -> indexSlice ix slix sh env
        IndexFull  ix slix sl   -> indexFull  ix slix sl env
        ToIndex sh ix           -> toIndex   sh ix env
        FromIndex sh ix         -> fromIndex sh ix env

        -- Arrays and indexing
        Index acc ix            -> index acc ix env
        LinearIndex acc ix      -> linearIndex acc ix env
        Shape acc               -> shape acc env
        ShapeSize sh            -> shapeSize sh env
        Intersect sh1 sh2       -> intersect sh1 sh2 env

        -- Foreign functions (the second field is not used here)
        Foreign ff _ e          -> foreignE ff e env

    -- The heavy lifting
    -- -----------------

    -- Scalar let expressions evaluate their terms and generate new (const)
    -- variable bindings to store these results. These are carried in the monad
    -- state, which also gives us a supply of fresh names. The new names are
    -- added to the environment for use in the body via the standard Var term.
    --
    -- Note that we have not restricted the scope of these new bindings: once
    -- something is added, it remains in scope forever. We are relying on
    -- liveness analysis of the CUDA compiler to manage register pressure.
    --
    elet :: DelayedOpenExp env aenv bnd -> DelayedOpenExp (env, bnd) aenv body -> Val env -> Gen [C.Exp]
    elet bnd body env = do
      -- Evaluate the bound term and register it with 'pushEnv'; the terms it
      -- hands back are what the body sees at index zero.
      bound <- pushEnv bnd =<< cvtE bnd env
      cvtE body (env `Push` bound)

    -- When evaluating primitive functions, we evaluate each argument to the
    -- operation as a statement expression. This is necessary to ensure proper
    -- short-circuit behaviour for logical operations.
    --
    primApp :: PrimFun (a -> b) -> DelayedOpenExp env aenv a -> Val env -> Gen C.Exp
    primApp f x env =
      case x of
        -- A pair argument means a binary primitive; anything else is unary.
        Tuple (NilTup `SnocTup` a `SnocTup` b) -> codegenPrim2 f <$> operand a env <*> operand b env
        _                                      -> codegenPrim1 f <$> operand x env
      where
        -- Generate one operand against a clean set of local bindings. If any
        -- bindings were needed they are wrapped, together with the value, in
        -- a statement expression; otherwise the bare value is used.
        operand :: DelayedOpenExp env aenv a -> Val env -> Gen C.Exp
        operand e env = do
          (stmts, val) <- clean (single "primApp" <$> cvtE e env)
          return $ case stmts of
            [] -> val
            _  -> [cexp| ({ $items:stmts; $exp:val; }) |]

    -- Convert an open expression into a sequence of C expressions. We retain
    -- snoc-list ordering, so the element at tuple index zero is at the end of
    -- the list. Note that nested tuple structures are flattened.
    --
    cvtT :: Tuple (DelayedOpenExp env aenv) t -> Val env -> Gen [C.Exp]
    cvtT NilTup           _   = return []
    -- The terms of the last component go at the end of the list.
    cvtT (SnocTup rest e) env = (++) <$> cvtT rest env <*> cvtE e env

    -- Project out a tuple index. Since the nested tuple structure is flattened,
    -- this actually corresponds to slicing out a subset of the list of C
    -- expressions, rather than picking out a single element.
    --
    prjT :: forall env t e. TupleIdx (TupleRepr t) e
         -> DelayedOpenExp env aenv t
         -> DelayedOpenExp env aenv e
         -> Val env
         -> Gen [C.Exp]
    prjT ix t e env = do
      fields <- cvtE t env
      let -- number of flattened terms to step over, counted from the end
          skip = prjToInt ix (Sugar.preExpType Sugar.delayedAccType t)
          -- number of flattened terms making up the projected component
          keep = length (expType e)
      return . reverse . take keep . drop skip . reverse $ fields

    -- Convert a tuple index into the corresponding integer. Since the internal
    -- representation is flat, be sure to walk over all sub components when indexing
    -- past nested tuples.
    --
    prjToInt :: TupleIdx t e -> TupleType a -> Int
    prjToInt tix ty =
      case (tix, ty) of
        (ZeroTupIdx,   _              ) -> 0
        -- Step past the whole (possibly nested) component that is skipped.
        (SuccTupIdx i, b `PairTuple` a) -> sizeTupleType a + prjToInt i b
        _                               -> $internalError "prjToInt" "inconsistent valuation"

    -- Number of scalar leaves in a (nested) tuple type.
    --
    sizeTupleType :: TupleType a -> Int
    sizeTupleType ty =
      case ty of
        UnitTuple     -> 0
        SingleTuple _ -> 1
        PairTuple l r -> sizeTupleType l + sizeTupleType r

    -- Scalar conditionals insert a standard if/else statement block. We don't
    -- use the ternary expression operator (?:) because this forces all
    -- auxiliary bindings for both the true and false branches to always be
    -- evaluated before the correct result is chosen.
    --
    cond :: forall env t. Elt t
         => DelayedOpenExp env aenv Bool
         -> DelayedOpenExp env aenv t
         -> DelayedOpenExp env aenv t
         -> Val env -> Gen [C.Exp]
    cond test yes no env = do
      -- Evaluate and bind the predicate, which must be a single expression.
      test'     <- cvtE test env
      flag      <- single "Cond" <$> pushEnv test test'

      -- Generate each arm against a clean set of local bindings, so that the
      -- bindings an arm needs end up inside its own branch.
      thenArm   <- clean (cvtE yes env)
      elseArm   <- clean (cvtE no  env)

      -- Fresh result variables, assigned in both branches. The name is
      -- prefixed to keep it apart from other generated names.
      name      <- lift fresh
      let (_, result, decls) = locals ('l':name) (undefined :: t)
          ifElse             = [citem| if ( $exp:flag ) {
                                           $items:(result .=. thenArm)
                                       }
                                       else {
                                           $items:(result .=. elseArm)
                                       } |]

      -- Local bindings are kept in reverse order, so the statement is placed
      -- in front of the declarations it relies on.
      modify (\s -> s { localBindings = ifElse : map C.BlockDecl decls ++ localBindings s })
      return result

    -- Value recursion
    --
    -- Generate a C 'while' loop for value recursion. The loop state lives in
    -- fresh mutable variables, which are also the result of the expression.
    -- Both function arguments must be unary lambdas; anything else is an
    -- internal error.
    --
    while :: forall env a. Elt a
          => DelayedOpenFun env aenv (a -> Bool)        -- continue while predicate returns true
          -> DelayedOpenFun env aenv (a -> a)           -- loop body
          -> DelayedOpenExp env aenv a                  -- initial value
          -> Val env
          -> Gen [C.Exp]
    while test step x env
      | Lam (Body p)    <- test
      , Lam (Body f)    <- step
      = do
           -- Generate code for the initial value, then bind this to a fresh
           -- (mutable) variable. We need to build the declarations ourselves,
           -- and twiddle the names a bit to avoid clobbering.
           --
           x'           <- cvtE x env
           var_acc      <- lift fresh
           var_ok       <- lift fresh
           var_tmp      <- lift fresh

           let (_, acc, decl_acc) = locals ('l':var_acc) (undefined :: a)
               (_, ok,  decl_ok)  = locals ('l':var_ok)  (undefined :: Bool)
               (tmp, _, _)        = locals ('l':var_tmp) (undefined :: a)

           -- Generate code for the predicate and body expressions, with the new
           -- names baked in directly. We can't use 'codegenFun1', because
           -- def-use analysis won't be able to see into this new function.
           --
           -- However, we do need to generate the function with a clean set of
           -- local bindings, and extract any new declarations afterwards.
           --
           p'   <- clean $ cvtE p (env `Push` acc)
           f'   <- clean $ cvtE f (env `Push` acc)

           -- Piece it all together. Note that declarations are added to the
           -- localBindings in reverse order. Also, we have to be careful not to
           -- assign the results of f' directly into acc. Why? If some of the
           -- variables in acc are referenced in f', then we risk overwriting
           -- values that are still needed to compute f'.
           --
           let loop = [citem| while ( $exp:(single "while" ok) ) {
                                  $items:(tmp .=. f')
                                  $items:(acc .=. tmp)
                                  $items:(ok  .=. p')
                              } |]
                    : reverse (ok  .=. p')
                   ++ reverse (acc .=. x')
                   ++ map C.BlockDecl decl_ok
                   ++ map C.BlockDecl decl_acc

           modify (\s -> s { localBindings = loop ++ localBindings s })
           return acc

      -- Report a malformed loop the same way as every other impossible case in
      -- this module, so that the failure carries a location and description.
      | otherwise
      = $internalError "while" "expected unary functions"

    -- Restrict indices based on a slice specification. In the SliceAll case we
    -- elide the presence of IndexAny from the head of slx, as this is not
    -- represented by any C term (Any ~ [])
    --
    indexSlice :: SliceIndex (EltRepr slix) sl co (EltRepr sh)
               -> DelayedOpenExp env aenv slix
               -> DelayedOpenExp env aenv sh
               -> Val env
               -> Gen [C.Exp]
    indexSlice sliceIndex slix sh env = do
      slix' <- cvtE slix env
      sh'   <- cvtE sh   env
      -- Both lists are walked back to front, and the result is flipped back
      -- into its original orientation afterwards.
      return . reverse $ keep sliceIndex (reverse slix') (reverse sh')
      where
        -- Retain the extent of every 'SliceAll' dimension and drop the
        -- 'SliceFixed' ones, consuming one slice term for each fixed dimension.
        keep :: SliceIndex slix sl co sh -> [C.Exp] -> [C.Exp] -> [C.Exp]
        keep SliceNil          _       _         = []
        keep (SliceAll   rest) slx     (sz:dims) = sz : keep rest slx dims
        keep (SliceFixed rest) (_:slx) (_:dims)  =      keep rest slx dims
        keep _                 _       _         = $internalError "IndexSlice" "unexpected shapes"

    -- Extend indices based on a slice specification. In the SliceAll case we
    -- elide the presence of Any from the head of slx.
    --
    indexFull :: SliceIndex (EltRepr slix) (EltRepr sl) co sh
              -> DelayedOpenExp env aenv slix
              -> DelayedOpenExp env aenv sl
              -> Val env
              -> Gen [C.Exp]
    indexFull sliceIndex slix sl env = do
      slix' <- cvtE slix env
      sl'   <- cvtE sl   env
      -- Both lists are walked back to front, and the result is flipped back
      -- into its original orientation afterwards.
      return . reverse $ grow sliceIndex (reverse slix') (reverse sl')
      where
        -- Take the next extent from the slice shape for a 'SliceAll'
        -- dimension, and from the slice index terms for a 'SliceFixed' one.
        grow :: SliceIndex slix sl co sh -> [C.Exp] -> [C.Exp] -> [C.Exp]
        grow SliceNil          _        _         = []
        grow (SliceAll   rest) slx      (sz:dims) = sz : grow rest slx dims
        grow (SliceFixed rest) (sz:slx) dims      = sz : grow rest slx dims
        grow _                 _        _         = $internalError "IndexFull" "unexpected shapes"

    -- Convert between linear and multidimensional indices
    --
    toIndex :: DelayedOpenExp env aenv sh -> DelayedOpenExp env aenv sh -> Val env -> Gen [C.Exp]
    toIndex sh ix env = do
      -- Every term of the extent and of the index is recorded with 'use'.
      extent <- cvtE sh env >>= mapM use
      coords <- cvtE ix env >>= mapM use
      return [ ctoIndex extent coords ]

    -- Expand a linear index into a multidimensional one. The statements
    -- returned by 'cfromIndex' are added to the local bindings.
    --
    fromIndex :: DelayedOpenExp env aenv sh -> DelayedOpenExp env aenv Int -> Val env -> Gen [C.Exp]
    fromIndex sh ix env = do
      extent <- cvtE sh env >>= mapM use
      linear <- single "fromIndex" <$> cvtE ix env
      name   <- lift fresh
      let (stmts, coords) = cfromIndex extent linear name
      -- Local bindings are stored in reverse order, hence the 'reverse'.
      modify (\st -> st { localBindings = reverse stmts ++ localBindings st })
      return coords

    -- Project out a single scalar element from an array. The array expression
    -- does not contain any free scalar variables (strictly flat data
    -- parallelism) and has been floated out to be replaced by an array index.
    --
    -- As we have a non-parametric array representation, be sure to bind the
    -- linear array index as it will be used to access each component of a
    -- tuple.
    --
    -- Note that after evaluating the linear array index we bind this to a fresh
    -- variable of type 'int', so there is an implicit conversion from
    -- Int -> Int32.
    --
    index :: (Shape sh, Elt e)
          => DelayedOpenAcc aenv (Array sh e)
          -> DelayedOpenExp env aenv sh
          -> Val env
          -> Gen [C.Exp]
    index acc ix env =
      case acc of
        Manifest (Avar idx) -> do
          let (sh, arr) = namesOfAvar aenv idx
          coords <- cvtE ix env >>= mapM use
          -- Bind the linearised index once; it is shared by the reads from
          -- every component array.
          i      <- bind cint (ctoIndex (cshape (expDim ix) sh) coords)
          return [ indexArray dev t (cvar a) i | (t, a) <- zip (accType acc) arr ]

        _ -> $internalError "Index" "expected array variable"


    -- As 'index', but the position is already a linear index: bind it to an
    -- 'int' variable and read every component array at that position.
    --
    linearIndex :: (Shape sh, Elt e)
                => DelayedOpenAcc aenv (Array sh e)
                -> DelayedOpenExp env aenv Int
                -> Val env
                -> Gen [C.Exp]
    linearIndex acc ix env =
      case acc of
        Manifest (Avar idx) -> do
          let (_, arr) = namesOfAvar aenv idx
          pos <- cvtE ix env >>= mapM use
          i   <- bind [cty| int |] (single "LinearIndex" pos)
          return [ indexArray dev t (cvar a) i | (t, a) <- zip (accType acc) arr ]

        _ -> $internalError "LinearIndex" "expected array variable"

    -- Array shapes created in this method refer to the shape of free array
    -- variables. As such, they are always passed as arguments to the kernel,
    -- not computed as part of the scalar expression. These shapes are
    -- transferred to the kernel as a structure, and so the individual fields
    -- need to be "unpacked", to work with our handling of tuple structures.
    --
    shape :: (Shape sh, Elt e) => DelayedOpenAcc aenv (Array sh e) -> Val env -> Gen [C.Exp]
    shape acc _env =
      case acc of
        -- Only the shape name of the array variable is needed here.
        Manifest (Avar idx) ->
          let (shName, _) = namesOfAvar aenv idx
          in  return (cshape (delayedDim acc) shName)

        _ -> $internalError "Shape" "expected array variable"

    -- The size of a shape, as the product of the extent in each dimension. The
    -- definition is inlined, but we could also call the C function helpers.
    --
    shapeSize :: DelayedOpenExp env aenv sh -> Val env -> Gen [C.Exp]
    shapeSize sh env = do
      dims <- cvtE sh env
      -- 'csize' collapses the extents into one expression.
      return [ csize dims ]

    -- Intersection of two shapes, taken as the minimum in each dimension.
    --
    intersect :: forall env sh. Elt sh
              => DelayedOpenExp env aenv sh
              -> DelayedOpenExp env aenv sh
              -> Val env -> Gen [C.Exp]
    intersect sh1 sh2 env = do
      xs <- cvtE sh1 env
      ys <- cvtE sh2 env
      -- Pair up the extents dimension by dimension.
      return [ ccall "min" [x,y] | (x, y) <- zip xs ys ]

    -- Foreign scalar functions. We need to extract any header files that might
    -- be required so they can be added to the top level definitions.
    --
    -- Additionally, we insert an explicit type cast from the foreign function
    -- result back into Accelerate types (c.f. Int vs int).
    --
    foreignE :: forall f a b env. (Sugar.Foreign f, Elt a, Elt b)
             => f a b
             -> DelayedOpenExp env aenv a
             -> Val env
             -> Gen [C.Exp]
    foreignE ff x env
      | Just (hs, f) <- canExecuteExp ff
      = do
          -- Remember the headers this function needs; they live in the state
          -- of the underlying monad, hence the 'lift'.
          lift $ modify (\st -> st { headers = foldl (\acc h -> Set.insert h acc) (headers st) hs })
          args <- cvtE x env
          mapM_ use args
          -- The arguments are cast according to the element type of 'a'.
          return [ccall f (ccastTup (Sugar.eltType (undefined::a)) args)]

      | otherwise
      = $internalError "codegenOpenExp" "Non-CUDA foreign expression encountered"

    -- Execute a command in a new environment. The old environment is replaced
    -- on exit, and the result and any new bindings generated are returned.
    --
    clean :: Gen a -> Gen ([C.BlockItem], a)
    clean this = do
      -- Stash the current bindings and start from an empty list.
      saved <- gets localBindings
      modify (\s -> s { localBindings = [] })
      r     <- this
      -- Collect what the action produced, then put the old bindings back.
      new   <- gets localBindings
      modify (\s -> s { localBindings = saved })
      -- Bindings are accumulated in reverse, so flip them on the way out.
      return (reverse new, r)

    -- Some terms demand we extract only singly typed expressions
    --
    single :: String -> [C.Exp] -> C.Exp
    single loc es =
      case es of
        [e] -> e
        -- 'loc' names the construct that needed a single term
        _   -> $internalError loc "expected single expression"


-- Scalar Primitives
-- -----------------

-- Primitive constants: dispatch on the constant to its type-directed
-- generator.
--
codegenPrimConst :: PrimConst a -> C.Exp
codegenPrimConst c =
  case c of
    PrimMinBound ty -> codegenMinBound ty
    PrimMaxBound ty -> codegenMaxBound ty
    PrimPi       ty -> codegenPi ty


-- Unary primitive functions. Operators are emitted inline, floating point
-- functions become calls whose name is built with 'postfix', and the rest are
-- delegated to type-directed helpers.
--
codegenPrim1 :: PrimFun f -> C.Exp -> C.Exp
codegenPrim1 op a =
  case op of
    PrimNeg              _ -> [cexp| - $exp:a|]
    PrimAbs             ty -> codegenAbs ty a
    PrimSig             ty -> codegenSig ty a
    PrimBNot             _ -> [cexp|~ $exp:a|]
    PrimRecip           ty -> codegenRecip ty a
    PrimSin             ty -> ccall (FloatingNumType ty `postfix` "sin")   [a]
    PrimCos             ty -> ccall (FloatingNumType ty `postfix` "cos")   [a]
    PrimTan             ty -> ccall (FloatingNumType ty `postfix` "tan")   [a]
    PrimAsin            ty -> ccall (FloatingNumType ty `postfix` "asin")  [a]
    PrimAcos            ty -> ccall (FloatingNumType ty `postfix` "acos")  [a]
    PrimAtan            ty -> ccall (FloatingNumType ty `postfix` "atan")  [a]
    PrimAsinh           ty -> ccall (FloatingNumType ty `postfix` "asinh") [a]
    PrimAcosh           ty -> ccall (FloatingNumType ty `postfix` "acosh") [a]
    PrimAtanh           ty -> ccall (FloatingNumType ty `postfix` "atanh") [a]
    PrimExpFloating     ty -> ccall (FloatingNumType ty `postfix` "exp")   [a]
    PrimSqrt            ty -> ccall (FloatingNumType ty `postfix` "sqrt")  [a]
    PrimLog             ty -> ccall (FloatingNumType ty `postfix` "log")   [a]
    PrimTruncate     ta tb -> codegenTruncate ta tb a
    PrimRound        ta tb -> codegenRound ta tb a
    PrimFloor        ta tb -> codegenFloor ta tb a
    PrimCeiling      ta tb -> codegenCeiling ta tb a
    PrimLNot               -> [cexp| ! $exp:a|]
    PrimOrd                -> codegenOrd a
    PrimChr                -> codegenChr a
    PrimBoolToInt          -> codegenBoolToInt a
    PrimFromIntegral ta tb -> codegenFromIntegral ta tb a
    -- anything else is not a unary primitive
    _                      -> $internalError "codegenPrim1" "unknown primitive function"

-- Binary primitive functions. Most map directly onto the corresponding C
-- operator; the remainder become function calls or are delegated to
-- type-directed helpers.
--
codegenPrim2 :: PrimFun f -> C.Exp -> C.Exp -> C.Exp
codegenPrim2 op a b =
  case op of
    PrimAdd              _ -> [cexp|$exp:a + $exp:b|]
    PrimSub              _ -> [cexp|$exp:a - $exp:b|]
    PrimMul              _ -> [cexp|$exp:a * $exp:b|]
    PrimQuot             _ -> [cexp|$exp:a / $exp:b|]
    PrimRem              _ -> [cexp|$exp:a % $exp:b|]
    PrimIDiv             _ -> ccall "idiv" [a,b]
    PrimMod              _ -> ccall "mod"  [a,b]
    PrimBAnd             _ -> [cexp|$exp:a & $exp:b|]
    PrimBOr              _ -> [cexp|$exp:a | $exp:b|]
    PrimBXor             _ -> [cexp|$exp:a ^ $exp:b|]
    PrimBShiftL          _ -> [cexp|$exp:a << $exp:b|]
    PrimBShiftR          _ -> [cexp|$exp:a >> $exp:b|]
    PrimBRotateL         _ -> ccall "rotateL" [a,b]
    PrimBRotateR         _ -> ccall "rotateR" [a,b]
    PrimFDiv             _ -> [cexp|$exp:a / $exp:b|]
    PrimFPow            ty -> ccall (FloatingNumType ty `postfix` "pow")   [a,b]
    PrimLogBase         ty -> codegenLogBase ty a b
    PrimAtan2           ty -> ccall (FloatingNumType ty `postfix` "atan2") [a,b]
    PrimLt               _ -> [cexp|$exp:a < $exp:b|]
    PrimGt               _ -> [cexp|$exp:a > $exp:b|]
    PrimLtEq             _ -> [cexp|$exp:a <= $exp:b|]
    PrimGtEq             _ -> [cexp|$exp:a >= $exp:b|]
    PrimEq               _ -> [cexp|$exp:a == $exp:b|]
    PrimNEq              _ -> [cexp|$exp:a != $exp:b|]
    PrimMax             ty -> codegenMax ty a b
    PrimMin             ty -> codegenMin ty a b
    PrimLAnd               -> [cexp|$exp:a && $exp:b|]
    PrimLOr                -> [cexp|$exp:a || $exp:b|]
    -- anything else is not a binary primitive
    _                      -> $internalError "codegenPrim2" "unknown primitive function"

-- Implementation of scalar primitives
--
-- | Turn a constant of (possibly nested) tuple type into the flat list of C
-- expressions for its scalar components. The unit tuple contributes nothing,
-- and the components of a pair appear left part first.
--
codegenConst :: TupleType a -> a -> [C.Exp]
codegenConst tup val =
  case tup of
    UnitTuple       -> []
    SingleTuple t   -> [codegenScalar t val]
    PairTuple tl tr ->
      case val of
        (l, r) -> codegenConst tl l ++ codegenConst tr r


-- Scalar constants
--
-- | Generate the C expression for a scalar constant, dispatching on whether
-- the scalar type is numeric or non-numeric.
--
codegenScalar :: ScalarType a -> a -> C.Exp
codegenScalar sty c =
  case sty of
    NumScalarType    t -> codegenNumScalar    t c
    NonNumScalarType t -> codegenNonNumScalar t c

-- | Generate the C expression for a numeric constant, dispatching on whether
-- the numeric type is integral or floating point.
--
codegenNumScalar :: NumType a -> a -> C.Exp
codegenNumScalar nty c =
  case nty of
    IntegralNumType t -> codegenIntegralScalar t c
    FloatingNumType t -> codegenFloatingScalar t c

-- | Generate an integral constant: the literal produced by 'cintegral',
-- explicitly cast to the C type that 'codegenIntegralType' gives for the
-- integral type.
--
codegenIntegralScalar :: IntegralType a -> a -> C.Exp
codegenIntegralScalar ity n =
  case integralDict ity of
    -- matching on the dictionary brings the class instances for 'a' into scope
    IntegralDict ->
      let ctype   = codegenIntegralType ity
          literal = cintegral n
      in  [cexp|($ty:ctype) $exp:literal|]

-- | Generate a floating point literal. Single precision values are rendered
-- with a trailing @f@; double precision values are rendered as shown.
--
-- NOTE(review): the literal text comes from 'show', which yields @NaN@ and
-- @Infinity@ for non-finite values; those would not be valid C literals.
-- Confirm whether such constants can reach this function.
--
codegenFloatingScalar :: FloatingType a -> a -> C.Exp
codegenFloatingScalar fty x =
  case fty of
    TypeFloat   _ -> singleLit x
    TypeCFloat  _ -> singleLit x
    TypeDouble  _ -> doubleLit x
    TypeCDouble _ -> doubleLit x
  where
    singleLit, doubleLit :: (Show b, Real b) => b -> C.Exp
    singleLit v = C.Const (C.FloatConst  (shows v "f") (toRational v) noLoc) noLoc
    doubleLit v = C.Const (C.DoubleConst (show v)      (toRational v) noLoc) noLoc

-- | Generate the C expression for a non-numeric constant.
--
-- Booleans go through 'cbool' and Haskell 'Char' values are emitted as a
-- character literal.
--
-- The C character types are emitted as an integer literal cast to the target
-- type rather than as a character literal. Going via 'chr' is partial: 'CChar'
-- and 'CSChar' are signed, so a negative value (for example 'minBound') makes
-- 'chr' throw during code generation. The numeric form covers the full range
-- of all three types.
--
codegenNonNumScalar :: NonNumType a -> a -> C.Exp
codegenNonNumScalar (TypeBool      _) x = cbool x
codegenNonNumScalar (TypeChar      _) x = [cexp|$char:x|]
codegenNonNumScalar ty@(TypeCChar  _) x = ccast (NonNumScalarType ty) (cintegral x)
codegenNonNumScalar ty@(TypeCUChar _) x = ccast (NonNumScalarType ty) (cintegral x)
codegenNonNumScalar ty@(TypeCSChar _) x = ccast (NonNumScalarType ty) (cintegral x)


-- Constant methods of floating
--
-- | The constant 'pi' as a literal of the given floating point type.
--
codegenPi :: FloatingType a -> C.Exp
codegenPi fty =
  case floatingDict fty of
    FloatingDict -> codegenFloatingScalar fty pi


-- Constant methods of bounded
--
-- | The smallest value of a bounded type, emitted as a constant of that type.
--
codegenMinBound :: BoundedType a -> C.Exp
codegenMinBound bty =
  case bty of
    IntegralBoundedType t ->
      case integralDict t of
        IntegralDict -> codegenIntegralScalar t minBound
    NonNumBoundedType t ->
      case nonNumDict t of
        NonNumDict -> codegenNonNumScalar t minBound


-- | The largest value of a bounded type, emitted as a constant of that type.
--
codegenMaxBound :: BoundedType a -> C.Exp
codegenMaxBound bty =
  case bty of
    IntegralBoundedType t ->
      case integralDict t of
        IntegralDict -> codegenIntegralScalar t maxBound
    NonNumBoundedType t ->
      case nonNumDict t of
        NonNumDict -> codegenNonNumScalar t maxBound


-- Methods from Num, Floating, Fractional and RealFrac
--
-- | Absolute value. Floating point operands call @fabs@ (with a precision
-- suffix chosen by 'postfix'); the word and C unsigned integral types are
-- returned unchanged; every other integral type calls @abs@.
--
codegenAbs :: NumType a -> C.Exp -> C.Exp
codegenAbs nty x =
  case nty of
    FloatingNumType _   -> ccall (nty `postfix` "fabs") [x]
    IntegralNumType ity
      | isUnsigned ity  -> x
      | otherwise       -> ccall "abs" [x]
  where
    -- The integral types for which the operand is passed through as-is.
    isUnsigned :: IntegralType t -> Bool
    isUnsigned (TypeWord    _) = True
    isUnsigned (TypeWord8   _) = True
    isUnsigned (TypeWord16  _) = True
    isUnsigned (TypeWord32  _) = True
    isUnsigned (TypeWord64  _) = True
    isUnsigned (TypeCUShort _) = True
    isUnsigned (TypeCUInt   _) = True
    isUnsigned (TypeCULong  _) = True
    isUnsigned (TypeCULLong _) = True
    isUnsigned _               = False


-- | Signum, dispatching to the integral or floating point implementation.
--
codegenSig :: NumType a -> C.Exp -> C.Exp
codegenSig nty x =
  case nty of
    IntegralNumType t -> codegenIntegralSig t x
    FloatingNumType t -> codegenFloatingSig t x

-- | Signum for integral types, yielding -1, 0 or 1 in the operand's own type.
--
-- The sign is computed with integer comparisons only, as
-- @(x > 0) - (x < 0)@, and the result is cast back to the C type of the
-- operand. This avoids routing an integer through the floating point
-- @copysign@ function, which makes the generated expression floating point
-- typed and dependent on implicit integer/floating conversions. For unsigned
-- types the second comparison is always false, so the result is 0 or 1.
--
codegenIntegralSig :: IntegralType a -> C.Exp -> C.Exp
codegenIntegralSig ty x =
  [cexp|($ty:(codegenIntegralType ty)) (($exp:x > $exp:zero) - ($exp:x < $exp:zero))|]
  where
    zero | IntegralDict <- integralDict ty = codegenIntegralScalar ty 0

-- | Signum for floating point types: zero when the operand compares equal to
-- zero, otherwise one carrying the sign of the operand (via @copysign@, with
-- a precision suffix chosen by 'postfix').
--
codegenFloatingSig :: FloatingType a -> C.Exp -> C.Exp
codegenFloatingSig fty x =
  case floatingDict fty of
    FloatingDict ->
      let zero   = codegenFloatingScalar fty 0
          one    = codegenFloatingScalar fty 1
          signed = ccall (FloatingNumType fty `postfix` "copysign") [one,x]
      in
      [cexp|$exp:x == $exp:zero ? $exp:zero : $exp:signed|]


-- | Reciprocal: the constant one of the given floating point type divided by
-- the operand.
--
codegenRecip :: FloatingType a -> C.Exp -> C.Exp
codegenRecip fty x =
  case floatingDict fty of
    FloatingDict ->
      let one = codegenFloatingScalar fty 1
      in  [cexp|$exp:one / $exp:x|]


-- | Logarithm to an arbitrary base, where the first operand is the base:
-- the @log@ of the second operand divided by the @log@ of the first.
--
codegenLogBase :: FloatingType a -> C.Exp -> C.Exp -> C.Exp
codegenLogBase ty base val = [cexp|$exp:(ln val) / $exp:(ln base)|]
  where
    -- @log@ with the precision suffix appropriate for the type
    ln :: C.Exp -> C.Exp
    ln e = ccall (FloatingNumType ty `postfix` "log") [e]


-- | Minimum of two values. Integral types call @min@ and floating point types
-- call @fmin@ (with a precision suffix chosen by 'postfix'). Non-numeric
-- operands are first cast to 'Int32' and then compared as integers.
--
codegenMin :: ScalarType a -> C.Exp -> C.Exp -> C.Exp
codegenMin sty a b =
  case sty of
    NumScalarType nty  -> ccall (nty `postfix` fname nty) [a,b]
    NonNumScalarType _ -> codegenMin int32 (ccast int32 a) (ccast int32 b)
  where
    int32 :: ScalarType Int32
    int32 = scalarType

    fname :: NumType t -> String
    fname (IntegralNumType _) = "min"
    fname (FloatingNumType _) = "fmin"


-- | Maximum of two values. Integral types call @max@ and floating point types
-- call @fmax@ (with a precision suffix chosen by 'postfix'). Non-numeric
-- operands are first cast to 'Int32' and then compared as integers.
--
codegenMax :: ScalarType a -> C.Exp -> C.Exp -> C.Exp
codegenMax sty a b =
  case sty of
    NumScalarType nty  -> ccall (nty `postfix` fname nty) [a,b]
    NonNumScalarType _ -> codegenMax int32 (ccast int32 a) (ccast int32 b)
  where
    int32 :: ScalarType Int32
    int32 = scalarType

    fname :: NumType t -> String
    fname (IntegralNumType _) = "max"
    fname (FloatingNumType _) = "fmax"


-- Type coercions
--
-- | Character to integer conversion: a cast to the C type used for 'Int'.
--
codegenOrd :: C.Exp -> C.Exp
codegenOrd e = ccast int e
  where
    int :: ScalarType Int
    int = scalarType

-- | Integer to character conversion: a cast to the C type used for 'Char'.
--
codegenChr :: C.Exp -> C.Exp
codegenChr e = ccast char e
  where
    char :: ScalarType Char
    char = scalarType

-- | Boolean to integer conversion: a cast to the C type used for 'Int'.
--
codegenBoolToInt :: C.Exp -> C.Exp
codegenBoolToInt e = ccast int e
  where
    int :: ScalarType Int
    int = scalarType

-- | Convert an integral value to another numeric type by casting to the
-- target type. The source type is not needed to generate the cast.
--
codegenFromIntegral :: IntegralType a -> NumType b -> C.Exp -> C.Exp
codegenFromIntegral _source target e = ccast (NumScalarType target) e

-- | Truncate a floating point value: call @trunc@ (with a precision suffix
-- chosen by 'postfix') and cast the result to the target integral type.
--
codegenTruncate :: FloatingType a -> IntegralType b -> C.Exp -> C.Exp
codegenTruncate ta tb x = ccast target truncated
  where
    target    = NumScalarType (IntegralNumType tb)
    truncated = ccall (FloatingNumType ta `postfix` "trunc") [x]

-- | Round a floating point value: call @round@ (with a precision suffix
-- chosen by 'postfix') and cast the result to the target integral type.
--
-- NOTE(review): C's @round@ rounds halfway cases away from zero, whereas
-- Haskell's 'round' rounds them to the nearest even integer. Confirm which
-- behaviour is intended here.
--
codegenRound :: FloatingType a -> IntegralType b -> C.Exp -> C.Exp
codegenRound ta tb x = ccast target rounded
  where
    target  = NumScalarType (IntegralNumType tb)
    rounded = ccall (FloatingNumType ta `postfix` "round") [x]

-- | Floor of a floating point value: call @floor@ (with a precision suffix
-- chosen by 'postfix') and cast the result to the target integral type.
--
codegenFloor :: FloatingType a -> IntegralType b -> C.Exp -> C.Exp
codegenFloor ta tb x = ccast target floored
  where
    target  = NumScalarType (IntegralNumType tb)
    floored = ccall (FloatingNumType ta `postfix` "floor") [x]

-- | Ceiling of a floating point value: call @ceil@ (with a precision suffix
-- chosen by 'postfix') and cast the result to the target integral type.
--
codegenCeiling :: FloatingType a -> IntegralType b -> C.Exp -> C.Exp
codegenCeiling ta tb x = ccast target raised
  where
    target = NumScalarType (IntegralNumType tb)
    raised = ccall (FloatingNumType ta `postfix` "ceil") [x]


-- Auxiliary Functions
-- -------------------

-- | Wrap an expression in a C cast to the C type that 'codegenScalarType'
-- gives for the scalar type.
--
ccast :: ScalarType a -> C.Exp -> C.Exp
ccast sty e =
  let ctype = codegenScalarType sty
  in  [cexp|($ty:ctype) $exp:e|]

-- | Cast a flat list of expressions component-wise to the scalar types of a
-- tuple type. Expressions are consumed left to right, one per scalar
-- component; any left over after the whole type has been traversed are
-- dropped. Running out of expressions is an internal error.
--
ccastTup :: TupleType e -> [C.Exp] -> [C.Exp]
ccastTup tup exps = fst (go tup exps)
  where
    -- Returns the casted expressions for this part of the type together with
    -- the expressions that were not consumed.
    go :: TupleType t -> [C.Exp] -> ([C.Exp],[C.Exp])
    go t es =
      case t of
        UnitTuple       -> ([], es)
        SingleTuple s   ->
          case es of
            e : rest -> ([ccast s e], rest)
            []       -> $internalError "ccastTup" "not enough expressions to match type"
        PairTuple tl tr ->
          let (ls, afterL) = go tl es
              (rs, afterR) = go tr afterL
          in  (ls ++ rs, afterR)


-- | Pick the precision-specific spelling of a function name: an @f@ is
-- appended for the single precision types ('Float' and 'CFloat'), and the
-- name is returned unchanged for every other numeric type.
--
postfix :: NumType a -> String -> String
postfix nty name =
  case nty of
    FloatingNumType (TypeFloat  _) -> single
    FloatingNumType (TypeCFloat _) -> single
    _                              -> name
  where
    single = name ++ "f"