{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE CPP, GADTs, BangPatterns, TypeOperators, PatternGuards #-}
{-# LANGUAGE TypeFamilies, ScopedTypeVariables, FlexibleContexts #-}
-- |
-- Module      : Data.Array.Accelerate.Interpreter
-- Copyright   : [2008..2011] Manuel M T Chakravarty, Gabriele Keller, Sean Lee, Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Manuel M T Chakravarty <chak@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- This interpreter is meant to be a reference implementation of the semantics
-- of the embedded array language.  The emphasis is on defining the semantics
-- clearly, not on performance.
--
-- /Surface types versus representation types/
--
-- As a general rule, we perform all computations on representation types and we store all data
-- as values of representation types.  To guarantee the type safety of the interpreter, this
-- currently implies a lot of conversions between surface and representation types.  Optimising
-- the code by eliminating back and forth conversions is fine, but only where it doesn't
-- negatively affects clarity — after all, the main purpose of the interpreter is to serve as an
-- executable specification.
--

module Data.Array.Accelerate.Interpreter (

  -- * Interpret an array expression
  Arrays, run, stream,

  -- Internal (hidden)
  evalPrim, evalPrimConst, evalPrj

) where

-- standard libraries
import Control.Monad
import Control.Monad.ST                                 ( ST )
import Data.Bits
import Data.Char                                        ( chr, ord )
import Prelude                                          hiding ( sum )

-- friends
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Array.Representation       hiding ( sliceIndex )
import Data.Array.Accelerate.Array.Sugar (
  Z(..), (:.)(..), Array(..), ArraysR(..), Arrays, Scalar, Vector, Segments)
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Tuple
import Data.Array.Accelerate.Array.Delayed              hiding ( force, delay, Delayed )
import qualified Data.Array.Accelerate.Smart            as Sugar
import qualified Data.Array.Accelerate.Array.Sugar      as Sugar
import qualified Data.Array.Accelerate.Array.Delayed    as Sugar

#include "accelerate.h"


-- Program execution
-- -----------------

-- |Run a complete embedded array program using the reference interpreter.
--
run :: Arrays a => Sugar.Acc a -> a
run = force . evalAcc . Sugar.convertAcc

-- | Stream a lazily read list of input arrays through the given program,
-- collecting results as we go
--
stream :: (Arrays a, Arrays b) => (Sugar.Acc a -> Sugar.Acc b) -> [a] -> [b]
stream afun = map (run1 acc)
  where
    acc = Sugar.convertAccFun1 afun

    run1 :: Afun (a -> b) -> a -> b
    run1 (Alam (Abody f)) = \a -> force (evalOpenAcc f (Empty `Push` a))
    run1 _                = error "Hey type checker! We can not get here!"


-- Delayed arrays
--
type Delayed a = Sugar.Delayed (Sugar.ArrRepr a)

delay :: Arrays a => a -> Delayed a
delay arr = go (Sugar.arrays arr) (Sugar.fromArr arr)
  where
    go :: ArraysR a -> a -> Sugar.Delayed a
    go ArraysRunit         ()       = DelayedUnit
    go ArraysRarray        a        = Sugar.delay a
    go (ArraysRpair r1 r2) (a1, a2) = DelayedPair (go r1 a1) (go r2 a2)

force :: forall a. Arrays a => Delayed a -> a
force arr = Sugar.toArr $ go (Sugar.arrays (undefined::a)) arr
  where
    go :: ArraysR a' -> Sugar.Delayed a' -> a'
    go ArraysRunit         DelayedUnit         = ()
    go ArraysRarray        a                   = Sugar.force a
    go (ArraysRpair r1 r2) (DelayedPair d1 d2) = (go r1 d1, go r2 d2)


-- Array expression evaluation
-- ---------------------------

-- Evaluate an open array expression
--
evalOpenAcc :: OpenAcc aenv a -> Val aenv -> Delayed a
evalOpenAcc (OpenAcc acc) = evalPreOpenAcc acc

evalPreOpenAcc :: forall aenv a. PreOpenAcc OpenAcc aenv a -> Val aenv -> Delayed a

evalPreOpenAcc (Alet acc1 acc2) aenv
  = let !arr1 = force $ evalOpenAcc acc1 aenv
    in evalOpenAcc acc2 (aenv `Push` arr1)

evalPreOpenAcc (Avar idx) aenv = delay $ prj idx aenv

evalPreOpenAcc (Atuple tup) aenv = delay (toTuple $ evalAtuple tup aenv :: a)

evalPreOpenAcc (Aprj ix (tup :: OpenAcc aenv arrs)) aenv =
  let tup'  = force $ evalOpenAcc tup aenv :: arrs
  in  delay $ evalPrj ix (fromTuple tup')

evalPreOpenAcc (Apply (Alam (Abody funAcc)) acc) aenv
  = let !arr = force $ evalOpenAcc acc aenv
    in evalOpenAcc funAcc (Empty `Push` arr)
evalPreOpenAcc (Apply _afun _acc) _aenv
  = error "GHC's pattern match checker is too dumb to figure that this case is impossible"

evalPreOpenAcc (Acond cond acc1 acc2) aenv
  = if (evalExp cond aenv) then evalOpenAcc acc1 aenv else evalOpenAcc acc2 aenv

evalPreOpenAcc (Use arr) _aenv = delay (Sugar.toArr arr :: a)

evalPreOpenAcc (Unit e) aenv = unitOp (evalExp e aenv)

evalPreOpenAcc (Generate sh f) aenv
  = generateOp (evalExp sh aenv) (evalFun f aenv)

evalPreOpenAcc (Reshape e acc) aenv 
  = reshapeOp (evalExp e aenv) (evalOpenAcc acc aenv)

evalPreOpenAcc (Replicate sliceIndex slix acc) aenv
  = replicateOp sliceIndex (evalExp slix aenv) (evalOpenAcc acc aenv)
  
evalPreOpenAcc (Index sliceIndex acc slix) aenv
  = indexOp sliceIndex (evalOpenAcc acc aenv) (evalExp slix aenv)

evalPreOpenAcc (Map f acc) aenv = mapOp (evalFun f aenv) (evalOpenAcc acc aenv)

evalPreOpenAcc (ZipWith f acc1 acc2) aenv
  = zipWithOp (evalFun f aenv) (evalOpenAcc acc1 aenv) (evalOpenAcc acc2 aenv)

evalPreOpenAcc (Fold f e acc) aenv
  = foldOp (evalFun f aenv) (evalExp e aenv) (evalOpenAcc acc aenv)

evalPreOpenAcc (Fold1 f acc) aenv
  = fold1Op (evalFun f aenv) (evalOpenAcc acc aenv)

evalPreOpenAcc (FoldSeg f e acc1 acc2) aenv
  = foldSegOp integralType
              (evalFun f aenv) (evalExp e aenv)
              (evalOpenAcc acc1 aenv) (evalOpenAcc acc2 aenv)

evalPreOpenAcc (Fold1Seg f acc1 acc2) aenv
  = fold1SegOp integralType
               (evalFun f aenv) (evalOpenAcc acc1 aenv) (evalOpenAcc acc2 aenv)

evalPreOpenAcc (Scanl f e acc) aenv
  = scanlOp (evalFun f aenv) (evalExp e aenv) (evalOpenAcc acc aenv)

evalPreOpenAcc (Scanl' f e acc) aenv
  = scanl'Op (evalFun f aenv) (evalExp e aenv) (evalOpenAcc acc aenv)

evalPreOpenAcc (Scanl1 f acc) aenv
  = scanl1Op (evalFun f aenv) (evalOpenAcc acc aenv)

evalPreOpenAcc (Scanr f e acc) aenv
  = scanrOp (evalFun f aenv) (evalExp e aenv) (evalOpenAcc acc aenv)

evalPreOpenAcc (Scanr' f e acc) aenv
  = scanr'Op (evalFun f aenv) (evalExp e aenv) (evalOpenAcc acc aenv)

evalPreOpenAcc (Scanr1 f acc) aenv
  = scanr1Op (evalFun f aenv) (evalOpenAcc acc aenv)

evalPreOpenAcc (Permute f dftAcc p acc) aenv
  = permuteOp (evalFun f aenv) (evalOpenAcc dftAcc aenv)
              (evalFun p aenv) (evalOpenAcc acc aenv)

evalPreOpenAcc (Backpermute e p acc) aenv
  = backpermuteOp (evalExp e aenv) (evalFun p aenv) (evalOpenAcc acc aenv)

evalPreOpenAcc (Stencil sten bndy acc) aenv
  = stencilOp (evalFun sten aenv) bndy (evalOpenAcc acc aenv)

evalPreOpenAcc (Stencil2 sten bndy1 acc1 bndy2 acc2) aenv
  = stencil2Op (evalFun sten aenv) bndy1 (evalOpenAcc acc1 aenv) bndy2 (evalOpenAcc acc2 aenv)

-- Evaluate a closed array expressions
--
evalAcc :: Acc a -> Delayed a
evalAcc acc = evalOpenAcc acc Empty


-- Array tuple construction and projection
--
evalAtuple :: Atuple (OpenAcc aenv) t -> Val aenv -> t
evalAtuple NilAtup        _    = ()
evalAtuple (SnocAtup t a) aenv = (evalAtuple t aenv, force $ evalOpenAcc a aenv)


-- Array primitives
-- ----------------

unitOp :: Sugar.Elt e => e -> Delayed (Scalar e)
unitOp e
  = DelayedPair DelayedUnit
  $ DelayedArray {shapeDA = (), repfDA = const (Sugar.fromElt e)}

generateOp :: (Sugar.Shape dim, Sugar.Elt e)
      => dim
      -> (dim -> e)
      -> Delayed (Array dim e)
generateOp sh rf
  = DelayedPair DelayedUnit
  $ DelayedArray (Sugar.fromElt sh) (Sugar.sinkFromElt rf)

reshapeOp :: Sugar.Shape dim
          => dim -> Delayed (Array dim' e) -> Delayed (Array dim e)
reshapeOp newShape darr@(DelayedPair DelayedUnit (DelayedArray {shapeDA = oldShape}))
  = let Array _ adata = force darr
    in
    BOUNDS_CHECK(check) "reshape" "shape mismatch" (Sugar.size newShape == size oldShape)
    $ delay $ Array (Sugar.fromElt newShape) adata

replicateOp :: (Sugar.Shape dim, Sugar.Elt slix)
            => SliceIndex (Sugar.EltRepr slix)
                          (Sugar.EltRepr sl)
                          co
                          (Sugar.EltRepr dim)
            -> slix
            -> Delayed (Array sl e)
            -> Delayed (Array dim e)
replicateOp sliceIndex slix (DelayedPair DelayedUnit (DelayedArray sh pf))
  = DelayedPair DelayedUnit (DelayedArray sh' (pf . pf'))
  where
    (sh', pf') = extend sliceIndex (Sugar.fromElt slix) sh

    extend :: SliceIndex slix sl co dim
           -> slix
           -> sl
           -> (dim, dim -> sl)
    extend (SliceNil)            ()        ()       = ((), const ())
    extend (SliceAll sliceIdx)   (slx, ()) (sl, sz)
      = let (dim', f') = extend sliceIdx slx sl
        in
        ((dim', sz), \(ix, i) -> (f' ix, i))
    extend (SliceFixed sliceIdx) (slx, sz) sl
      = let (dim', f') = extend sliceIdx slx sl
        in
        ((dim', sz), \(ix, _) -> f' ix)

indexOp :: (Sugar.Shape sl, Sugar.Elt slix)
        => SliceIndex (Sugar.EltRepr slix)
                      (Sugar.EltRepr sl)
                      co
                      (Sugar.EltRepr dim)
        -> Delayed (Array dim e)
        -> slix
        -> Delayed (Array sl e)
indexOp sliceIndex (DelayedPair DelayedUnit (DelayedArray sh pf)) slix
  = DelayedPair DelayedUnit (DelayedArray sh' (pf . pf'))
  where
    (sh', pf') = restrict sliceIndex (Sugar.fromElt slix) sh

    restrict :: SliceIndex slix sl co dim
             -> slix
             -> dim
             -> (sl, sl -> dim)
    restrict (SliceNil)            ()        ()       = ((), const ())
    restrict (SliceAll sliceIdx)   (slx, ()) (sl, sz)
      = let (sl', f') = restrict sliceIdx slx sl
        in
        ((sl', sz), \(ix, i) -> (f' ix, i))
    restrict (SliceFixed sliceIdx) (slx, i)  (sl, sz)
      = let (sl', f') = restrict sliceIdx slx sl
        in
        BOUNDS_CHECK(checkIndex) "index" i sz $ (sl', \ix -> (f' ix, i))

mapOp :: Sugar.Elt e' 
      => (e -> e') 
      -> Delayed (Array dim e) 
      -> Delayed (Array dim e')
mapOp f (DelayedPair DelayedUnit (DelayedArray sh rf))
  = DelayedPair DelayedUnit
  $ DelayedArray sh (Sugar.sinkFromElt f . rf)

zipWithOp :: Sugar.Elt e3
          => (e1 -> e2 -> e3) 
          -> Delayed (Array dim e1) 
          -> Delayed (Array dim e2) 
          -> Delayed (Array dim e3)
zipWithOp f (DelayedPair DelayedUnit (DelayedArray sh1 rf1)) (DelayedPair DelayedUnit (DelayedArray sh2 rf2))
  = DelayedPair DelayedUnit
  $ DelayedArray (sh1 `intersect` sh2) 
                 (\ix -> (Sugar.sinkFromElt2 f) (rf1 ix) (rf2 ix))

foldOp :: Sugar.Shape dim
       => (e -> e -> e)
       -> e
       -> Delayed (Array (dim:.Int) e)
       -> Delayed (Array dim e)
foldOp f e (DelayedPair DelayedUnit (DelayedArray (sh, n) rf))
  | size sh == 0
  = DelayedPair DelayedUnit
  $ DelayedArray (listToShape . map (max 1) . shapeToList $ sh)
      (\_ -> Sugar.fromElt e)
  --
  | otherwise
  = DelayedPair DelayedUnit
  $ DelayedArray sh
      (\ix -> iter ((), n) (\((), i) -> rf (ix, i)) (Sugar.sinkFromElt2 f) (Sugar.fromElt e))

fold1Op :: Sugar.Shape dim
        => (e -> e -> e)
        -> Delayed (Array (dim:.Int) e)
        -> Delayed (Array dim e)
fold1Op f (DelayedPair DelayedUnit (DelayedArray (sh, n) rf))
  = DelayedPair DelayedUnit
  $ DelayedArray sh (\ix -> iter1 ((), n) (\((), i) -> rf (ix, i)) (Sugar.sinkFromElt2 f))

foldSegOp :: IntegralType i
          -> (e -> e -> e)
          -> e
          -> Delayed (Array (dim:.Int) e)
          -> Delayed (Segments i)
          -> Delayed (Array (dim:.Int) e)
foldSegOp ty f e arr seg
  | IntegralDict <- integralDict ty = foldSegOp' f e arr seg

foldSegOp' :: forall i e dim. Integral i
           => (e -> e -> e)
           -> e
           -> Delayed (Array (dim:.Int) e)
           -> Delayed (Segments i)
           -> Delayed (Array (dim:.Int) e)
foldSegOp' f e (DelayedPair DelayedUnit (DelayedArray (sh, _n) rf)) seg@(DelayedPair DelayedUnit (DelayedArray shSeg rfSeg))
  = delay arr
  where
    DelayedPair (DelayedPair DelayedUnit (DelayedArray _shSeg rfStarts)) _ = scanl'Op (+) 0 seg
    arr = Sugar.newArray (Sugar.toElt (sh, Sugar.toElt shSeg)) foldOne
    --
    foldOne :: dim:.Int -> e
    foldOne ix = let
                   (ix', i) = Sugar.fromElt ix
                   start    = fromIntegral ((Sugar.liftToElt rfStarts) i :: i)
                   len      = fromIntegral ((Sugar.liftToElt rfSeg)    i :: i)
                 in
                 fold ix' e start (start + len)

    fold :: Sugar.EltRepr dim -> e -> Int -> Int -> e
    fold ix' !v j end
      | j >= end  = v
      | otherwise = fold ix' (f v (Sugar.toElt . rf $ (ix', j))) (j + 1) end


fold1SegOp :: IntegralType i
           -> (e -> e -> e)
           -> Delayed (Array (dim:.Int) e)
           -> Delayed (Segments i)
           -> Delayed (Array (dim:.Int) e)
fold1SegOp ty f arr seg
  | IntegralDict <- integralDict ty = fold1SegOp' f arr seg


fold1SegOp' :: forall i e dim. Integral i
            => (e -> e -> e)
            -> Delayed (Array (dim:.Int) e)
            -> Delayed (Segments i)
            -> Delayed (Array (dim:.Int) e)
fold1SegOp' f (DelayedPair DelayedUnit (DelayedArray (sh, _n) rf)) seg@(DelayedPair DelayedUnit (DelayedArray shSeg rfSeg))
  = delay arr
  where
    DelayedPair prefix _sum                                = scanl'Op (+) 0 seg
    DelayedPair DelayedUnit (DelayedArray _shSeg rfStarts) = prefix
    arr = Sugar.newArray (Sugar.toElt (sh, Sugar.toElt shSeg)) foldOne
    --
    foldOne :: dim:.Int -> e
    foldOne ix = let
                   (ix', i) = Sugar.fromElt ix
                   start    = fromIntegral ((Sugar.liftToElt rfStarts) i :: i)
                   len      = fromIntegral ((Sugar.liftToElt rfSeg)    i :: i)
                 in
                 if len == 0
                   then
                     BOUNDS_ERROR(error) "fold1Seg" "empty iteration space"
                   else
                     fold ix' (Sugar.toElt . rf $ (ix', start)) (start + 1) (start + len)

    fold :: Sugar.EltRepr dim -> e -> Int -> Int -> e
    fold ix' !v j end
      | j >= end  = v
      | otherwise = fold ix' (f v (Sugar.toElt . rf $ (ix', j))) (j + 1) end


scanlOp :: forall e. (e -> e -> e)
        -> e
        -> Delayed (Vector e)
        -> Delayed (Vector e)
scanlOp f e (DelayedPair DelayedUnit (DelayedArray sh rf))
  = delay $ adata `seq` Array ((), n + 1) adata
  where
    n  = size sh
    f' = Sugar.sinkFromElt2 f
    --
    (adata, _) = runArrayData $ do
                   arr   <- newArrayData (n + 1)
                   final <- traverse arr 0 (Sugar.fromElt e)
                   writeArrayData arr n final
                   return (arr, undefined)

    traverse :: MutableArrayData s (Sugar.EltRepr e) -> Int -> (Sugar.EltRepr e) -> ST s (Sugar.EltRepr e)
    traverse arr i v
      | i >= n    = return v
      | otherwise = do
                      writeArrayData arr i v
                      traverse arr (i + 1) (f' v (rf ((), i)))

scanl'Op :: forall e. (e -> e -> e)
         -> e
         -> Delayed (Vector e)
         -> Delayed (Vector e, Scalar e)
scanl'Op f e (DelayedPair DelayedUnit (DelayedArray sh rf))
  = DelayedPair (delay $ adata `seq` Array sh adata) final
  where
    n  = size sh
    f' = Sugar.sinkFromElt2 f
    --
    DelayedPair DelayedUnit final = unitOp (Sugar.toElt asum)

    (adata, asum) = runArrayData $ do
                      arr <- newArrayData n
                      sum <- traverse arr 0 (Sugar.fromElt e)
                      return (arr, sum)

    traverse :: MutableArrayData s (Sugar.EltRepr e) -> Int -> (Sugar.EltRepr e) -> ST s (Sugar.EltRepr e)
    traverse arr i v
      | i >= n    = return v
      | otherwise = do
                      writeArrayData arr i v
                      traverse arr (i + 1) (f' v (rf ((), i)))

scanl1Op :: forall e. (e -> e -> e)
         -> Delayed (Vector e)
         -> Delayed (Vector e)
scanl1Op f (DelayedPair DelayedUnit (DelayedArray sh rf))
  = delay $ adata `seq` Array sh adata
  where
    n  = size sh
    f' = Sugar.sinkFromElt2 f
    --
    (adata, _) = runArrayData $ do
                   arr <- newArrayData n
                   traverse arr 0 undefined
                   return (arr, undefined)

    traverse :: MutableArrayData s (Sugar.EltRepr e) -> Int -> (Sugar.EltRepr e) -> ST s ()
    traverse arr i v
      | i >= n    = return ()
      | i == 0    = do
                      let e = rf ((), i)
                      writeArrayData arr i e
                      traverse arr (i + 1) e
      | otherwise = do
                      let e = f' v (rf ((), i))
                      writeArrayData arr i e
                      traverse arr (i + 1) e

scanrOp :: forall e. (e -> e -> e)
        -> e
        -> Delayed (Vector e)
        -> Delayed (Vector e)
scanrOp f e (DelayedPair DelayedUnit (DelayedArray sh rf))
  = delay $ adata `seq` Array ((), n + 1) adata
  where
    n  = size sh
    f' = Sugar.sinkFromElt2 f
    --
    (adata, _) = runArrayData $ do
                   arr   <- newArrayData (n + 1)
                   final <- traverse arr n (Sugar.fromElt e)
                   writeArrayData arr 0 final
                   return (arr, undefined)

    traverse :: MutableArrayData s (Sugar.EltRepr e) -> Int -> (Sugar.EltRepr e) -> ST s (Sugar.EltRepr e)
    traverse arr i v
      | i == 0    = return v
      | otherwise = do
                      writeArrayData arr i v
                      traverse arr (i - 1) (f' v (rf ((), i-1)))

scanr'Op :: forall e. (e -> e -> e)
         -> e
         -> Delayed (Vector e)
         -> Delayed (Vector e, Scalar e)
scanr'Op f e (DelayedPair DelayedUnit (DelayedArray sh rf))
  = DelayedPair (delay $ adata `seq` Array sh adata) final
  where
    n  = size sh
    f' = Sugar.sinkFromElt2 f
    --
    DelayedPair DelayedUnit final = unitOp (Sugar.toElt asum)

    (adata, asum) = runArrayData $ do
                      arr <- newArrayData n
                      sum <- traverse arr (n-1) (Sugar.fromElt e)
                      return (arr, sum)

    traverse :: MutableArrayData s (Sugar.EltRepr e) -> Int -> (Sugar.EltRepr e) -> ST s (Sugar.EltRepr e)
    traverse arr i v
      | i < 0     = return v
      | otherwise = do
                      writeArrayData arr i v
                      traverse arr (i - 1) (f' v (rf ((), i)))

scanr1Op :: forall e. (e -> e -> e)
         -> Delayed (Vector e)
         -> Delayed (Vector e)
scanr1Op f (DelayedPair DelayedUnit (DelayedArray sh rf))
  = delay $ adata `seq` Array sh adata
  where
    n  = size sh
    f' = Sugar.sinkFromElt2 f
    --
    (adata, _) = runArrayData $ do
                   arr <- newArrayData n
                   traverse arr (n - 1) undefined
                   return (arr, undefined)

    traverse :: MutableArrayData s (Sugar.EltRepr e) -> Int -> (Sugar.EltRepr e) -> ST s ()
    traverse arr i v
      | i < 0        = return ()
      | i == (n - 1) = do
                         let e = rf ((), i)
                         writeArrayData arr i e
                         traverse arr (i - 1) e
      | otherwise    = do
                         let e = f' v (rf ((), i))
                         writeArrayData arr i e
                         traverse arr (i - 1) e

permuteOp :: (e -> e -> e)
          -> Delayed (Array dim' e)
          -> (dim -> dim')
          -> Delayed (Array dim e)
          -> Delayed (Array dim' e)
permuteOp f (DelayedPair DelayedUnit (DelayedArray dftsSh dftsPf))
          p (DelayedPair DelayedUnit (DelayedArray sh pf))
  = delay $ adata `seq` Array dftsSh adata
  where 
    f' = Sugar.sinkFromElt2 f
    --
    (adata, _) 
      = runArrayData $ do

            -- new array in target dimension
          arr <- newArrayData (size dftsSh)

            -- initialise it with the default values
          let write ix = writeArrayData arr (index dftsSh ix) (dftsPf ix)      
          iter dftsSh write (>>) (return ())

            -- traverse the source dimension and project each element into
            -- the target dimension (where it gets combined with the current
            -- default)
          let update ix = do
                            let target = (Sugar.sinkFromElt p) ix
                            unless (target == ignore) $ do
                              let i = index dftsSh target
                              e <- readArrayData arr i
                              writeArrayData arr i (pf ix `f'` e) 
          iter sh update (>>) (return ())

            -- return the updated array
          return (arr, undefined)

backpermuteOp :: Sugar.Shape dim'
              => dim'
              -> (dim' -> dim)
              -> Delayed (Array dim e)
              -> Delayed (Array dim' e)
backpermuteOp sh' p (DelayedPair DelayedUnit (DelayedArray _sh rf))
  = DelayedPair DelayedUnit
  $ DelayedArray (Sugar.fromElt sh') (rf . Sugar.sinkFromElt p)

stencilOp :: forall dim e e' stencil. (Sugar.Elt e, Sugar.Elt e', Stencil dim e stencil)
          => (stencil -> e')
          -> Boundary (Sugar.EltRepr e)
          -> Delayed (Array dim e)
          -> Delayed (Array dim e')
stencilOp sten bndy (DelayedPair DelayedUnit (DelayedArray sh rf))
  = DelayedPair DelayedUnit
  $ DelayedArray sh rf'
  where
    rf' = Sugar.sinkFromElt (sten . stencilAccess rfBounded)

    -- add a boundary to the source array as specified by the boundary condition
    rfBounded :: dim -> e
    rfBounded ix = Sugar.toElt $ case Sugar.bound (Sugar.toElt sh) ix bndy of
                                    Left v    -> v
                                    Right ix' -> rf (Sugar.fromElt ix')

stencil2Op :: forall dim e1 e2 e' stencil1 stencil2. 
              (Sugar.Elt e1, Sugar.Elt e2, Sugar.Elt e', 
               Stencil dim e1 stencil1, Stencil dim e2 stencil2)
           => (stencil1 -> stencil2 -> e')
           -> Boundary (Sugar.EltRepr e1)
           -> Delayed (Array dim e1)
           -> Boundary (Sugar.EltRepr e2)
           -> Delayed (Array dim e2)
           -> Delayed (Array dim e')
stencil2Op sten bndy1 (DelayedPair DelayedUnit (DelayedArray sh1 rf1))
                bndy2 (DelayedPair DelayedUnit (DelayedArray sh2 rf2))
  = DelayedPair DelayedUnit (DelayedArray (sh1 `intersect` sh2) rf')
  where
    rf' = Sugar.sinkFromElt (\ix -> sten (stencilAccess rf1Bounded ix)
                                         (stencilAccess rf2Bounded ix))

    -- add a boundary to the source arrays as specified by the boundary conditions
    rf1Bounded :: dim -> e1
    rf1Bounded ix = Sugar.toElt $ case Sugar.bound (Sugar.toElt sh1) ix bndy1 of
                                     Left v    -> v
                                     Right ix' -> rf1 (Sugar.fromElt ix')

    rf2Bounded :: dim -> e2
    rf2Bounded ix = Sugar.toElt $ case Sugar.bound (Sugar.toElt sh2) ix bndy2 of
                                     Left v    -> v
                                     Right ix' -> rf2 (Sugar.fromElt ix')


-- Expression evaluation
-- ---------------------

-- Evaluate open function
--
evalOpenFun :: OpenFun env aenv t -> ValElt env -> Val aenv -> t
evalOpenFun (Body e) env aenv = evalOpenExp e env aenv
evalOpenFun (Lam f)  env aenv 
  = \x -> evalOpenFun f (env `PushElt` Sugar.fromElt x) aenv

-- Evaluate a closed function
--
evalFun :: Fun aenv t -> Val aenv -> t
evalFun f aenv = evalOpenFun f EmptyElt aenv

-- Evaluate an open expression
--
-- NB: The implementation of 'IndexScalar' and 'Shape' demonstrate clearly why
--     array expressions must be hoisted out of scalar expressions before code
--     execution.  If these operations are in the body of a function that
--     gets mapped over an array, the array argument would be forced many times
--     leading to a large amount of wasteful recomputation.
--  
evalOpenExp :: OpenExp env aenv a -> ValElt env -> Val aenv -> a

evalOpenExp (Let exp1 exp2) env aenv
  = let !v1 = evalOpenExp exp1 env aenv
    in evalOpenExp exp2 (env `PushElt` Sugar.fromElt v1) aenv

evalOpenExp (Var idx) env _
  = prjElt idx env

evalOpenExp (Const c) _ _
  = Sugar.toElt c

evalOpenExp (Tuple tup) env aenv 
  = toTuple $ evalTuple tup env aenv

evalOpenExp (Prj idx e) env aenv 
  = evalPrj idx (fromTuple $ evalOpenExp e env aenv)

evalOpenExp IndexNil _env _aenv 
  = Z

evalOpenExp (IndexCons sh i) env aenv 
  = evalOpenExp sh env aenv :. evalOpenExp i env aenv

evalOpenExp (IndexHead ix) env aenv 
  = case evalOpenExp ix env aenv of _:.h -> h

evalOpenExp (IndexTail ix) env aenv 
  = case evalOpenExp ix env aenv of t:._ -> t

evalOpenExp (IndexAny) _ _
  = Sugar.Any

evalOpenExp (Cond c t e) env aenv 
  = if evalOpenExp c env aenv
    then evalOpenExp t env aenv
    else evalOpenExp e env aenv

evalOpenExp (PrimConst c) _ _ = evalPrimConst c

evalOpenExp (PrimApp p arg) env aenv 
  = evalPrim p (evalOpenExp arg env aenv)

evalOpenExp (IndexScalar acc ix) env aenv 
  = case evalOpenAcc acc aenv of
      DelayedPair DelayedUnit (DelayedArray sh pf) ->
        let ix' = Sugar.fromElt $ evalOpenExp ix env aenv
        in
        index sh ix' `seq` (Sugar.toElt $ pf ix')
                              -- FIXME: This is ugly, but (possibly) needed to
                              --       ensure bounds checking

evalOpenExp (Shape acc) _ aenv 
  = case evalOpenAcc acc aenv of
      DelayedPair DelayedUnit (DelayedArray sh _) -> Sugar.toElt sh

evalOpenExp (ShapeSize sh) env aenv = size $ Sugar.fromElt $ evalOpenExp sh env aenv

-- Evaluate a closed expression
--
evalExp :: Exp aenv t -> Val aenv -> t
evalExp e aenv = evalOpenExp e EmptyElt aenv


-- Scalar primitives
-- -----------------

evalPrimConst :: PrimConst a -> a
evalPrimConst (PrimMinBound ty) = evalMinBound ty
evalPrimConst (PrimMaxBound ty) = evalMaxBound ty
evalPrimConst (PrimPi       ty) = evalPi ty

evalPrim :: PrimFun p -> p
evalPrim (PrimAdd             ty) = evalAdd ty
evalPrim (PrimSub             ty) = evalSub ty
evalPrim (PrimMul             ty) = evalMul ty
evalPrim (PrimNeg             ty) = evalNeg ty
evalPrim (PrimAbs             ty) = evalAbs ty
evalPrim (PrimSig             ty) = evalSig ty
evalPrim (PrimQuot            ty) = evalQuot ty
evalPrim (PrimRem             ty) = evalRem ty
evalPrim (PrimIDiv            ty) = evalIDiv ty
evalPrim (PrimMod             ty) = evalMod ty
evalPrim (PrimBAnd            ty) = evalBAnd ty
evalPrim (PrimBOr             ty) = evalBOr ty
evalPrim (PrimBXor            ty) = evalBXor ty
evalPrim (PrimBNot            ty) = evalBNot ty
evalPrim (PrimBShiftL         ty) = evalBShiftL ty
evalPrim (PrimBShiftR         ty) = evalBShiftR ty
evalPrim (PrimBRotateL        ty) = evalBRotateL ty
evalPrim (PrimBRotateR        ty) = evalBRotateR ty
evalPrim (PrimFDiv            ty) = evalFDiv ty
evalPrim (PrimRecip           ty) = evalRecip ty
evalPrim (PrimSin             ty) = evalSin ty
evalPrim (PrimCos             ty) = evalCos ty
evalPrim (PrimTan             ty) = evalTan ty
evalPrim (PrimAsin            ty) = evalAsin ty
evalPrim (PrimAcos            ty) = evalAcos ty
evalPrim (PrimAtan            ty) = evalAtan ty
evalPrim (PrimAsinh           ty) = evalAsinh ty
evalPrim (PrimAcosh           ty) = evalAcosh ty
evalPrim (PrimAtanh           ty) = evalAtanh ty
evalPrim (PrimExpFloating     ty) = evalExpFloating ty
evalPrim (PrimSqrt            ty) = evalSqrt ty
evalPrim (PrimLog             ty) = evalLog ty
evalPrim (PrimFPow            ty) = evalFPow ty
evalPrim (PrimLogBase         ty) = evalLogBase ty
evalPrim (PrimTruncate     ta tb) = evalTruncate ta tb
evalPrim (PrimRound        ta tb) = evalRound ta tb
evalPrim (PrimFloor        ta tb) = evalFloor ta tb
evalPrim (PrimCeiling      ta tb) = evalCeiling ta tb
evalPrim (PrimAtan2           ty) = evalAtan2 ty
evalPrim (PrimLt              ty) = evalLt ty
evalPrim (PrimGt              ty) = evalGt ty
evalPrim (PrimLtEq            ty) = evalLtEq ty
evalPrim (PrimGtEq            ty) = evalGtEq ty
evalPrim (PrimEq              ty) = evalEq ty
evalPrim (PrimNEq             ty) = evalNEq ty
evalPrim (PrimMax             ty) = evalMax ty
evalPrim (PrimMin             ty) = evalMin ty
evalPrim PrimLAnd                 = evalLAnd
evalPrim PrimLOr                  = evalLOr
evalPrim PrimLNot                 = evalLNot
evalPrim PrimOrd                  = evalOrd
evalPrim PrimChr                  = evalChr
evalPrim PrimBoolToInt            = evalBoolToInt
evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb


-- Tuple construction and projection
-- ---------------------------------

evalTuple :: Tuple (OpenExp env aenv) t -> ValElt env -> Val aenv -> t
evalTuple NilTup            _env _aenv = ()
evalTuple (tup `SnocTup` e) env  aenv  = (evalTuple tup env aenv, evalOpenExp e env aenv)

evalPrj :: TupleIdx t e -> t -> e
evalPrj ZeroTupIdx       (!_, v)   = v
evalPrj (SuccTupIdx idx) (tup, !_) = evalPrj idx tup
  -- FIXME: Strictly speaking, we ought to force all components of a tuples;
  --        not only those that we happen to encounter during the recursive
  --        walk.


-- Implementation of scalar primitives
-- -----------------------------------

evalLAnd :: (Bool, Bool) -> Bool
evalLAnd (!x, !y) = x && y

evalLOr  :: (Bool, Bool) -> Bool
evalLOr (!x, !y) = x || y

evalLNot :: Bool -> Bool
evalLNot = not

evalOrd :: Char -> Int
evalOrd = ord

evalChr :: Int -> Char
evalChr = chr

evalBoolToInt :: Bool -> Int
evalBoolToInt = fromEnum

evalFromIntegral :: IntegralType a -> NumType b -> a -> b
evalFromIntegral ta (IntegralNumType tb)
  | IntegralDict <- integralDict ta
  , IntegralDict <- integralDict tb = fromIntegral
evalFromIntegral ta (FloatingNumType tb)
  | IntegralDict <- integralDict ta
  , FloatingDict <- floatingDict tb = fromIntegral


-- Extract methods from reified dictionaries
-- 

-- Constant methods of Bounded
-- 

evalMinBound :: BoundedType a -> a
evalMinBound (IntegralBoundedType ty) 
  | IntegralDict <- integralDict ty = minBound
evalMinBound (NonNumBoundedType   ty) 
  | NonNumDict   <- nonNumDict ty   = minBound

evalMaxBound :: BoundedType a -> a
evalMaxBound (IntegralBoundedType ty) 
  | IntegralDict <- integralDict ty = maxBound
evalMaxBound (NonNumBoundedType   ty) 
  | NonNumDict   <- nonNumDict ty   = maxBound

-- Constant method of floating
-- 

evalPi :: FloatingType a -> a
evalPi ty | FloatingDict <- floatingDict ty = pi

evalSin :: FloatingType a -> (a -> a)
evalSin ty | FloatingDict <- floatingDict ty = sin

evalCos :: FloatingType a -> (a -> a)
evalCos ty | FloatingDict <- floatingDict ty = cos

evalTan :: FloatingType a -> (a -> a)
evalTan ty | FloatingDict <- floatingDict ty = tan

evalAsin :: FloatingType a -> (a -> a)
evalAsin ty | FloatingDict <- floatingDict ty = asin

evalAcos :: FloatingType a -> (a -> a)
evalAcos ty | FloatingDict <- floatingDict ty = acos

evalAtan :: FloatingType a -> (a -> a)
evalAtan ty | FloatingDict <- floatingDict ty = atan

evalAsinh :: FloatingType a -> (a -> a)
evalAsinh ty | FloatingDict <- floatingDict ty = asinh

evalAcosh :: FloatingType a -> (a -> a)
evalAcosh ty | FloatingDict <- floatingDict ty = acosh

evalAtanh :: FloatingType a -> (a -> a)
evalAtanh ty | FloatingDict <- floatingDict ty = atanh

evalExpFloating :: FloatingType a -> (a -> a)
evalExpFloating ty | FloatingDict <- floatingDict ty = exp

evalSqrt :: FloatingType a -> (a -> a)
evalSqrt ty | FloatingDict <- floatingDict ty = sqrt

evalLog :: FloatingType a -> (a -> a)
evalLog ty | FloatingDict <- floatingDict ty = log

evalFPow :: FloatingType a -> ((a, a) -> a)
evalFPow ty | FloatingDict <- floatingDict ty = uncurry (**)

evalLogBase :: FloatingType a -> ((a, a) -> a)
evalLogBase ty | FloatingDict <- floatingDict ty = uncurry logBase

evalTruncate :: FloatingType a -> IntegralType b -> (a -> b)
evalTruncate ta tb
  | FloatingDict <- floatingDict ta
  , IntegralDict <- integralDict tb = truncate

evalRound :: FloatingType a -> IntegralType b -> (a -> b)
evalRound ta tb
  | FloatingDict <- floatingDict ta
  , IntegralDict <- integralDict tb = round

evalFloor :: FloatingType a -> IntegralType b -> (a -> b)
evalFloor ta tb
  | FloatingDict <- floatingDict ta
  , IntegralDict <- integralDict tb = floor

evalCeiling :: FloatingType a -> IntegralType b -> (a -> b)
evalCeiling ta tb
  | FloatingDict <- floatingDict ta
  , IntegralDict <- integralDict tb = ceiling

evalAtan2 :: FloatingType a -> ((a, a) -> a)
evalAtan2 ty | FloatingDict <- floatingDict ty = uncurry atan2


-- Methods of Num
-- 

evalAdd :: NumType a -> ((a, a) -> a)
evalAdd (IntegralNumType ty) | IntegralDict <- integralDict ty = uncurry (+)
evalAdd (FloatingNumType ty) | FloatingDict <- floatingDict ty = uncurry (+)

evalSub :: NumType a -> ((a, a) -> a)
evalSub (IntegralNumType ty) | IntegralDict <- integralDict ty = uncurry (-)
evalSub (FloatingNumType ty) | FloatingDict <- floatingDict ty = uncurry (-)

evalMul :: NumType a -> ((a, a) -> a)
evalMul (IntegralNumType ty) | IntegralDict <- integralDict ty = uncurry (*)
evalMul (FloatingNumType ty) | FloatingDict <- floatingDict ty = uncurry (*)

evalNeg :: NumType a -> (a -> a)
evalNeg (IntegralNumType ty) | IntegralDict <- integralDict ty = negate
evalNeg (FloatingNumType ty) | FloatingDict <- floatingDict ty = negate

evalAbs :: NumType a -> (a -> a)
evalAbs (IntegralNumType ty) | IntegralDict <- integralDict ty = abs
evalAbs (FloatingNumType ty) | FloatingDict <- floatingDict ty = abs

evalSig :: NumType a -> (a -> a)
evalSig (IntegralNumType ty) | IntegralDict <- integralDict ty = signum
evalSig (FloatingNumType ty) | FloatingDict <- floatingDict ty = signum

evalQuot :: IntegralType a -> ((a, a) -> a)
evalQuot ty | IntegralDict <- integralDict ty = uncurry quot

evalRem :: IntegralType a -> ((a, a) -> a)
evalRem ty | IntegralDict <- integralDict ty = uncurry rem

evalIDiv :: IntegralType a -> ((a, a) -> a)
evalIDiv ty | IntegralDict <- integralDict ty = uncurry div

evalMod :: IntegralType a -> ((a, a) -> a)
evalMod ty | IntegralDict <- integralDict ty = uncurry mod

evalBAnd :: IntegralType a -> ((a, a) -> a)
evalBAnd ty | IntegralDict <- integralDict ty = uncurry (.&.)

evalBOr :: IntegralType a -> ((a, a) -> a)
evalBOr ty | IntegralDict <- integralDict ty = uncurry (.|.)

evalBXor :: IntegralType a -> ((a, a) -> a)
evalBXor ty | IntegralDict <- integralDict ty = uncurry xor

evalBNot :: IntegralType a -> (a -> a)
evalBNot ty | IntegralDict <- integralDict ty = complement

evalBShiftL :: IntegralType a -> ((a, Int) -> a)
evalBShiftL ty | IntegralDict <- integralDict ty = uncurry shiftL

evalBShiftR :: IntegralType a -> ((a, Int) -> a)
evalBShiftR ty | IntegralDict <- integralDict ty = uncurry shiftR

evalBRotateL :: IntegralType a -> ((a, Int) -> a)
evalBRotateL ty | IntegralDict <- integralDict ty = uncurry rotateL

evalBRotateR :: IntegralType a -> ((a, Int) -> a)
evalBRotateR ty | IntegralDict <- integralDict ty = uncurry rotateR

evalFDiv :: FloatingType a -> ((a, a) -> a)
evalFDiv ty | FloatingDict <- floatingDict ty = uncurry (/)

evalRecip :: FloatingType a -> (a -> a)
evalRecip ty | FloatingDict <- floatingDict ty = recip



evalLt :: ScalarType a -> ((a, a) -> Bool)
evalLt (NumScalarType (IntegralNumType ty)) 
  | IntegralDict <- integralDict ty = uncurry (<)
evalLt (NumScalarType (FloatingNumType ty)) 
  | FloatingDict <- floatingDict ty = uncurry (<)
evalLt (NonNumScalarType ty) 
  | NonNumDict   <- nonNumDict ty   = uncurry (<)

evalGt :: ScalarType a -> ((a, a) -> Bool)
evalGt (NumScalarType (IntegralNumType ty)) 
  | IntegralDict <- integralDict ty = uncurry (>)
evalGt (NumScalarType (FloatingNumType ty)) 
  | FloatingDict <- floatingDict ty = uncurry (>)
evalGt (NonNumScalarType ty) 
  | NonNumDict   <- nonNumDict ty   = uncurry (>)

evalLtEq :: ScalarType a -> ((a, a) -> Bool)
evalLtEq (NumScalarType (IntegralNumType ty)) 
  | IntegralDict <- integralDict ty = uncurry (<=)
evalLtEq (NumScalarType (FloatingNumType ty)) 
  | FloatingDict <- floatingDict ty = uncurry (<=)
evalLtEq (NonNumScalarType ty) 
  | NonNumDict   <- nonNumDict ty   = uncurry (<=)

evalGtEq :: ScalarType a -> ((a, a) -> Bool)
evalGtEq (NumScalarType (IntegralNumType ty)) 
  | IntegralDict <- integralDict ty = uncurry (>=)
evalGtEq (NumScalarType (FloatingNumType ty)) 
  | FloatingDict <- floatingDict ty = uncurry (>=)
evalGtEq (NonNumScalarType ty) 
  | NonNumDict   <- nonNumDict ty   = uncurry (>=)

evalEq :: ScalarType a -> ((a, a) -> Bool)
evalEq (NumScalarType (IntegralNumType ty)) 
  | IntegralDict <- integralDict ty = uncurry (==)
evalEq (NumScalarType (FloatingNumType ty)) 
  | FloatingDict <- floatingDict ty = uncurry (==)
evalEq (NonNumScalarType ty) 
  | NonNumDict   <- nonNumDict ty   = uncurry (==)

evalNEq :: ScalarType a -> ((a, a) -> Bool)
evalNEq (NumScalarType (IntegralNumType ty)) 
  | IntegralDict <- integralDict ty = uncurry (/=)
evalNEq (NumScalarType (FloatingNumType ty)) 
  | FloatingDict <- floatingDict ty = uncurry (/=)
evalNEq (NonNumScalarType ty) 
  | NonNumDict   <- nonNumDict ty   = uncurry (/=)

evalMax :: ScalarType a -> ((a, a) -> a)
evalMax (NumScalarType (IntegralNumType ty)) 
  | IntegralDict <- integralDict ty = uncurry max
evalMax (NumScalarType (FloatingNumType ty)) 
  | FloatingDict <- floatingDict ty = uncurry max
evalMax (NonNumScalarType ty) 
  | NonNumDict   <- nonNumDict ty   = uncurry max

evalMin :: ScalarType a -> ((a, a) -> a)
evalMin (NumScalarType (IntegralNumType ty)) 
  | IntegralDict <- integralDict ty = uncurry min
evalMin (NumScalarType (FloatingNumType ty)) 
  | FloatingDict <- floatingDict ty = uncurry min
evalMin (NonNumScalarType ty) 
  | NonNumDict   <- nonNumDict ty   = uncurry min