{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.CodeGen.Exp
-- Copyright   : [2015..2017] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <tmcdonell@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.CodeGen.Exp
  where

import Control.Applicative                                          hiding ( Const )
import Control.Monad
import Data.Proxy
import Data.Typeable
import Text.Printf
import Prelude                                                      hiding ( exp, any )
import qualified Data.IntMap                                        as IM

import Data.Array.Accelerate.AST                                    hiding ( Val(..), prj )
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Sugar                            hiding ( Foreign, toTuple, shape, intersect, union )
import Data.Array.Accelerate.Array.Representation                   ( SliceIndex(..) )
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Product
import Data.Array.Accelerate.Trafo
import Data.Array.Accelerate.Type
import qualified Data.Array.Accelerate.Array.Sugar                  as A

import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Operand

import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.CodeGen.Type                      hiding ( typeOf )
import Data.Array.Accelerate.LLVM.Foreign
import qualified Data.Array.Accelerate.LLVM.CodeGen.Arithmetic      as A
import qualified Data.Array.Accelerate.LLVM.CodeGen.Loop            as L


-- Scalar expressions
-- ==================

{-# INLINEABLE llvmOfFun1 #-}
llvmOfFun1
    :: Foreign arch
    => arch
    -> DelayedFun aenv (a -> b)
    -> Gamma aenv
    -> IRFun1 arch aenv (a -> b)
llvmOfFun1 arch (Lam (Body body)) aenv = IRFun1 $ \x -> llvmOfOpenExp arch body (Empty `Push` x) aenv
llvmOfFun1 _ _ _                       = $internalError "llvmOfFun1" "impossible evaluation"

{-# INLINEABLE llvmOfFun2 #-}
llvmOfFun2
    :: Foreign arch
    => arch
    -> DelayedFun aenv (a -> b -> c)
    -> Gamma aenv
    -> IRFun2 arch aenv (a -> b -> c)
llvmOfFun2 arch (Lam (Lam (Body body))) aenv = IRFun2 $ \x y -> llvmOfOpenExp arch body (Empty `Push` x `Push` y) aenv
llvmOfFun2 _ _ _                             = $internalError "llvmOfFun2" "impossible evaluation"


-- | Convert an open scalar expression into a sequence of LLVM IR instructions.
-- Code is generated in depth first order, and uses a monad to collect the
-- sequence of instructions used to construct basic blocks.
--
{-# INLINEABLE llvmOfOpenExp #-}
llvmOfOpenExp
    :: forall arch env aenv _t. Foreign arch
    => arch
    -> DelayedOpenExp env aenv _t
    -> Val env
    -> Gamma aenv
    -> IROpenExp arch env aenv _t
llvmOfOpenExp arch top env aenv = cvtE top
  where
    cvtM :: DelayedOpenAcc aenv (Array sh e) -> IRManifest arch aenv (Array sh e)
    cvtM (Manifest (Avar ix)) = IRManifest ix
    cvtM _                    = $internalError "llvmOfOpenExp" "expected manifest array variable"

    cvtF1 :: DelayedOpenFun env aenv (a -> b) -> IROpenFun1 arch env aenv (a -> b)
    cvtF1 (Lam (Body body)) = IRFun1 $ \x -> llvmOfOpenExp arch body (env `Push` x) aenv
    cvtF1 _                 = $internalError "cvtF1" "impossible evaluation"

    cvtE :: forall t. DelayedOpenExp env aenv t -> IROpenExp arch env aenv t
    cvtE exp =
      case exp of
        Let bnd body                -> do x <- cvtE bnd
                                          llvmOfOpenExp arch body (env `Push` x) aenv
        Var ix                      -> return $ prj ix env
        Const c                     -> return $ IR (constant (eltType (undefined::t)) c)
        PrimConst c                 -> return $ IR (constant (eltType (undefined::t)) (fromElt (primConst c)))
        PrimApp f x                 -> primFun f x
        Undef                       -> return undefE
        IndexNil                    -> return indexNil
        IndexAny                    -> return indexAny
        IndexCons sh sz             -> indexCons <$> cvtE sh <*> cvtE sz
        IndexHead ix                -> indexHead <$> cvtE ix
        IndexTail ix                -> indexTail <$> cvtE ix
        Prj ix tup                  -> prjT ix =<< cvtE tup
        Tuple tup                   -> cvtT tup
        Foreign asm f x             -> foreignE asm f =<< cvtE x
        Cond c t e                  -> A.ifThenElse (cvtE c) (cvtE t) (cvtE e)
        IndexSlice slice slix sh    -> indexSlice slice <$> cvtE slix <*> cvtE sh
        IndexFull slice slix sh     -> indexFull slice  <$> cvtE slix <*> cvtE sh
        ToIndex sh ix               -> join $ intOfIndex <$> cvtE sh <*> cvtE ix
        FromIndex sh ix             -> join $ indexOfInt <$> cvtE sh <*> cvtE ix
        Index acc ix                -> index (cvtM acc)       =<< cvtE ix
        LinearIndex acc ix          -> linearIndex (cvtM acc) =<< cvtE ix
        ShapeSize sh                -> shapeSize              =<< cvtE sh
        Shape acc                   -> return $ shape (cvtM acc)
        Intersect sh1 sh2           -> join $ intersect <$> cvtE sh1 <*> cvtE sh2
        Union sh1 sh2               -> join $ union     <$> cvtE sh1 <*> cvtE sh2
        While c f x                 -> while (cvtF1 c) (cvtF1 f) (cvtE x)
        Coerce x                    -> coerce =<< cvtE x

    indexNil :: IR Z
    indexNil = IR (constant (eltType Z) (fromElt Z))

    indexAny :: forall sh. Shape sh => IR (Any sh)
    indexAny = let any = Any :: Any sh
               in  IR (constant (eltType any) (fromElt any))

    undefE :: forall t. Elt t => IR t
    undefE = IR $ go (eltType (undefined::t))
      where
        go :: TupleType s -> Operands s
        go TypeRunit       = OP_Unit
        go (TypeRscalar t) = ir' t (undef t)
        go (TypeRpair a b) = OP_Pair (go a) (go b)

    indexSlice :: SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
               -> IR slix
               -> IR sh
               -> IR sl
    indexSlice slice (IR slix) (IR sh) = IR $ restrict slice slix sh
      where
        restrict :: SliceIndex slix sl co sh -> Operands slix -> Operands sh -> Operands sl
        restrict SliceNil              OP_Unit               OP_Unit          = OP_Unit
        restrict (SliceAll sliceIdx)   (OP_Pair slx OP_Unit) (OP_Pair sl sz)  =
          let sl' = restrict sliceIdx slx sl
          in  OP_Pair sl' sz
        restrict (SliceFixed sliceIdx) (OP_Pair slx _i)      (OP_Pair sl _sz) =
          restrict sliceIdx slx sl

    indexFull :: SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
              -> IR slix
              -> IR sl
              -> IR sh
    indexFull slice (IR slix) (IR sh) = IR $ extend slice slix sh
      where
        extend :: SliceIndex slix sl co sh -> Operands slix -> Operands sl -> Operands sh
        extend SliceNil              OP_Unit               OP_Unit         = OP_Unit
        extend (SliceAll sliceIdx)   (OP_Pair slx OP_Unit) (OP_Pair sl sz) =
          let sh' = extend sliceIdx slx sl
          in  OP_Pair sh' sz
        extend (SliceFixed sliceIdx) (OP_Pair slx sz)      sl              =
          let sh' = extend sliceIdx slx sl
          in  OP_Pair sh' sz

    prjT :: forall t e. (Elt t, IsTuple t, Elt e) => TupleIdx (TupleRepr t) e -> IR t -> CodeGen (IR e)
    prjT tix (IR tup) =
      case eltType (undefined::t) of
        TypeRscalar (VectorScalarType v) -> goV tix v tup
        t                                -> goT tix t tup
      where
        -- for unzipped tuples
        goT :: TupleIdx s e -> TupleType t' -> Operands t' -> CodeGen (IR e)
        goT (SuccTupIdx ix) (TypeRpair t _) (OP_Pair x _) = goT ix t x
        goT ZeroTupIdx      (TypeRpair _ t) (OP_Pair _ x)
          | Just Refl <- matchTupleType t (eltType (undefined::e))
          = return $ IR x
        goT _ _ _
          = $internalError "prjT/tup" "inconsistent valuation"

        -- for SIMD vectors
        goV :: forall (v :: * -> *) a. TupleIdx (ProdRepr t) e -> VectorType (v a) -> Operands (v a) -> CodeGen (IR e)
        goV vix v (op' v -> vec)
          | Just Refl <- matchProdR (prod Proxy (undefined::t)) (vecProdR v)
          , Just Refl <- matchVecT (eltType (undefined::e)) (vecElemT v)
          = instr $ ExtractElement v vix vec
        goV _ _ _
          = $internalError "prjT/vec" "inconsistent valuation"

        matchVecT :: TupleType (EltRepr e) -> TupleType a -> Maybe (e :~: a)
        matchVecT e v
          | Just Refl <- matchTupleType e v = gcast Refl
          | otherwise                       = Nothing

        vecElemT :: VectorType (v a) -> TupleType a
        vecElemT (Vector2Type  a) = TypeRscalar (SingleScalarType a)
        vecElemT (Vector3Type  a) = TypeRscalar (SingleScalarType a)
        vecElemT (Vector4Type  a) = TypeRscalar (SingleScalarType a)
        vecElemT (Vector8Type  a) = TypeRscalar (SingleScalarType a)
        vecElemT (Vector16Type a) = TypeRscalar (SingleScalarType a)

        matchProdR :: ProdR Elt a -> ProdR Elt b -> Maybe (a :~: b)
        matchProdR ProdRunit        ProdRunit        = Just Refl
        matchProdR pa@(ProdRsnoc a) pb@(ProdRsnoc b)
          | Just Refl <- matchProdR a b
          , Just Refl <- matchTop pa pb
          = Just Refl
          where
            matchTop :: forall ta tb a b. (Elt a, Elt b) => ProdR Elt (ta,a) -> ProdR Elt (tb,b) -> Maybe (a :~: b)
            matchTop _ _
              | Just Refl <- matchTupleType (eltType (undefined::a)) (eltType (undefined::b)) = gcast Refl
              | otherwise                                                                     = Nothing
        matchProdR _ _
          = Nothing

        vecProdR :: VectorType v -> ProdR Elt (ProdRepr v)
        vecProdR (Vector2Type  e) | EltDict :: EltDict a <- singleElt e = prod Proxy (undefined::V2 a)
        vecProdR (Vector3Type  e) | EltDict :: EltDict a <- singleElt e = prod Proxy (undefined::V3 a)
        vecProdR (Vector4Type  e) | EltDict :: EltDict a <- singleElt e = prod Proxy (undefined::V4 a)
        vecProdR (Vector8Type  e) | EltDict :: EltDict a <- singleElt e = prod Proxy (undefined::V8 a)
        vecProdR (Vector16Type e) | EltDict :: EltDict a <- singleElt e = prod Proxy (undefined::V16 a)

    cvtT :: forall t. (Elt t, IsTuple t) => Tuple (DelayedOpenExp env aenv) (TupleRepr t) -> CodeGen (IR t)
    cvtT tup =
      case eltType (undefined::t) of
        TypeRscalar (VectorScalarType v) -> IR <$> goV v tup
        t                                -> IR <$> goT t tup
      where
        -- for unzipped tuples
        goT :: TupleType t' -> Tuple (DelayedOpenExp env aenv) tup -> CodeGen (Operands t')
        goT TypeRunit NilTup
          = return OP_Unit
        goT (TypeRpair ta tb) (SnocTup a (b :: DelayedOpenExp env aenv b))
          -- We must assert that the reified type 'tb' of 'b' is actually
          -- equivalent to the type of 'b'. This can not fail, but is necessary
          -- because 'tb' observes the representation type of surface type 'b'.
          | Just Refl <- matchTupleType tb (eltType (undefined::b))
          = do a'    <- goT ta a
               IR b' <- cvtE b
               return $ OP_Pair a' b'
        goT _ _
          = $internalError "cvtT/tup"
          $ unlines [ "impossible evaluation"
                    , "  possible solution: ensure that the 'EltRepr' and 'ProdRepr' instances of your data type are consistent." ]

        -- for packed SIMD vectors
        goV :: forall (v :: * -> *) a. VectorType (v a) -> Tuple (DelayedOpenExp env aenv) (TupleRepr t) -> CodeGen (Operands (v a))
        goV v ts = ir' v . snd <$> pack ts
          where
            pack :: Tuple (DelayedOpenExp env aenv) tup -> CodeGen (Int32, Operand (v a))
            pack NilTup
              = return (0, undef (VectorScalarType v))
            pack (SnocTup t x)
              | Just Refl <- matchExpType x
              = do
                  x'        <- cvtE x
                  (i, vec)  <- pack t
                  vec'      <- instr' $ InsertElement i vec (op a x')
                  return (i+1, vec')
                where
                  matchExpType :: forall s. Elt s => DelayedOpenExp env aenv s -> Maybe (s :~: a)
                  matchExpType _
                    | Just Refl <- matchTupleType (eltType (undefined::s)) (TypeRscalar (SingleScalarType a)) = gcast Refl
                    | otherwise                                                                               = Nothing
            pack _
              = $internalError "cvtT/vec" "impossible evaluation"

            a :: SingleType a
            a = case v of
                  Vector2Type  t -> t
                  Vector3Type  t -> t
                  Vector4Type  t -> t
                  Vector8Type  t -> t
                  Vector16Type t -> t

    linearIndex :: (Shape sh, Elt e) => IRManifest arch aenv (Array sh e) -> IR Int -> CodeGen (IR e)
    linearIndex (IRManifest v) ix =
      readArray (irArray (aprj v aenv)) ix

    index :: (Shape sh, Elt e) => IRManifest arch aenv (Array sh e) -> IR sh -> CodeGen (IR e)
    index (IRManifest v) ix =
      let arr = irArray (aprj v aenv)
      in  readArray arr =<< intOfIndex (irArrayShape arr) ix

    shape :: (Shape sh, Elt e) => IRManifest arch aenv (Array sh e) -> IR sh
    shape (IRManifest v) = irArrayShape (irArray (aprj v aenv))

    shapeSize :: forall sh. Shape sh => IR sh -> CodeGen (IR Int)
    shapeSize (IR extent) = go (eltType (undefined::sh)) extent
      where
        go :: TupleType t -> Operands t -> CodeGen (IR Int)
        go TypeRunit OP_Unit
          = return $ IR (constant (eltType (undefined :: Int)) 1)
        go (TypeRpair tsh t) (OP_Pair sh sz)
          | Just Refl <- matchTupleType t (eltType (undefined::Int))
          = do
               a <- go tsh sh
               b <- A.mul numType a (IR sz)
               return b
        go (TypeRscalar t) (op' t -> i)
          | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
          = return $ ir t i
        go _ _
          = $internalError "shapeSize" "expected shape with Int components"

    intersect :: forall sh. Shape sh => IR sh -> IR sh -> CodeGen (IR sh)
    intersect (IR extent1) (IR extent2) = IR <$> go (eltType (undefined::sh)) extent1 extent2
      where
        go :: TupleType t -> Operands t -> Operands t -> CodeGen (Operands t)
        go TypeRunit OP_Unit OP_Unit
          = return OP_Unit
        go (TypeRscalar t) sh1 sh2
          | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)       -- TLM: GHC hang if this is omitted
          = do IR x <- A.min (singleType :: SingleType Int) (IR sh1) (IR sh2)
               return x
        go (TypeRpair tsh tsz) (OP_Pair sh1 sz1) (OP_Pair sh2 sz2)
          = do
               sz' <- go tsz sz1 sz2
               sh' <- go tsh sh1 sh2
               return $ OP_Pair sh' sz'
        go _ _ _
          = $internalError "intersect" "expected shape with Int components"

    union :: forall sh. Shape sh => IR sh -> IR sh -> CodeGen (IR sh)
    union (IR extent1) (IR extent2) = IR <$> go (eltType (undefined::sh)) extent1 extent2
      where
        go :: TupleType t -> Operands t -> Operands t -> CodeGen (Operands t)
        go TypeRunit OP_Unit OP_Unit
          = return OP_Unit
        go (TypeRscalar t) sh1 sh2
          | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)       -- TLM: GHC hang if this is omitted
          = do IR x <- A.max (singleType :: SingleType Int) (IR sh1) (IR sh2)
               return x
        go (TypeRpair tsh tsz) (OP_Pair sh1 sz1) (OP_Pair sh2 sz2)
          = do
               sz' <- go tsz sz1 sz2
               sh' <- go tsh sh1 sh2
               return $ OP_Pair sh' sz'
        go _ _ _
          = $internalError "union" "expected shape with Int components"

    while :: Elt a
          => IROpenFun1 arch env aenv (a -> Bool)
          -> IROpenFun1 arch env aenv (a -> a)
          -> IROpenExp  arch env aenv a
          -> IROpenExp  arch env aenv a
    while p f x =
      L.while (app1 p) (app1 f) =<< x

    foreignE :: (Elt a, Elt b, Foreign arch, A.Foreign asm)
             => asm           (a -> b)
             -> DelayedFun () (a -> b)
             -> IR a
             -> IRExp arch () b
    foreignE asm no x =
      case foreignExp arch asm of
        Just f                       -> app1 f x
        Nothing | Lam (Body b) <- no -> llvmOfOpenExp arch b (Empty `Push` x) IM.empty
        _                            -> error "when a grid's misaligned with another behind / that's a moiré..."

    coerce :: forall a b. (Elt a, Elt b) => IR a -> CodeGen (IR b)
    coerce (IR as) = IR <$> go (eltType (undefined::a)) (eltType (undefined::b)) as
      where
        go :: TupleType s -> TupleType t -> Operands s -> CodeGen (Operands t)
        go TypeRunit         TypeRunit         OP_Unit         = return OP_Unit
        go (TypeRpair s1 s2) (TypeRpair t1 t2) (OP_Pair x1 x2) = OP_Pair <$> go s1 t1 x1 <*> go s2 t2 x2
        go (TypeRscalar s)   (TypeRscalar t)   x
          | Just Refl <- matchScalarType s t = return x
          | otherwise                        = ir' t <$> instr' (BitCast t (op' s x))
        --
        go (TypeRpair TypeRunit s) t@TypeRscalar{} (OP_Pair OP_Unit x) = go s t x
        go s@TypeRscalar{} (TypeRpair TypeRunit t) x                   = OP_Pair OP_Unit <$> go s t x
        go _ _ _
          = error $ printf "could not coerce type `%s' to `%s'"
                      (show (typeOf (undefined::a)))
                      (show (typeOf (undefined::b)))

    primFun :: Elt r
            => PrimFun (a -> r)
            -> DelayedOpenExp env aenv a
            -> CodeGen (IR r)
    primFun f x =
      let
          -- The Accelerate language and its code generator are hyper-strict.
          -- However, we must not eagerly evaluate the arguments to logical
          -- operations (&&*) and (||*) so that they can short-circuit. Since we
          -- only have unary functions, this is a little tricky for us.
          --
          -- 'inl' and 'inr' attempt to destruct the incoming AST so that we can
          -- evaluate the left or right components of a pair individually. It
          -- should be noted that there are other cases which can evaluate to
          -- pairs; 'Constant', 'Let' and 'Var', for example, but these cases
          -- are (probably) not applicable in this context.
          --
          inl :: (Elt a, Elt b) => DelayedOpenExp env aenv (a,b) -> IROpenExp arch env aenv a
          inl (Tuple (SnocTup (SnocTup NilTup a) _)) = cvtE a
          inl t                                      = cvtE $ Prj (SuccTupIdx ZeroTupIdx) t

          inr :: (Elt a, Elt b) => DelayedOpenExp env aenv (a,b) -> IROpenExp arch env aenv b
          inr (Tuple (SnocTup _ b)) = cvtE b
          inr t                     = cvtE $ Prj ZeroTupIdx t
      in
      case f of
        PrimAdd t                 -> A.uncurry (A.add t)     =<< cvtE x
        PrimSub t                 -> A.uncurry (A.sub t)     =<< cvtE x
        PrimMul t                 -> A.uncurry (A.mul t)     =<< cvtE x
        PrimNeg t                 -> A.negate t              =<< cvtE x
        PrimAbs t                 -> A.abs t                 =<< cvtE x
        PrimSig t                 -> A.signum t              =<< cvtE x
        PrimQuot t                -> A.uncurry (A.quot t)    =<< cvtE x
        PrimRem t                 -> A.uncurry (A.rem t)     =<< cvtE x
        PrimQuotRem t             -> A.uncurry (A.quotRem t) =<< cvtE x
        PrimIDiv t                -> A.uncurry (A.idiv t)    =<< cvtE x
        PrimMod t                 -> A.uncurry (A.mod t)     =<< cvtE x
        PrimDivMod t              -> A.uncurry (A.divMod t)  =<< cvtE x
        PrimBAnd t                -> A.uncurry (A.band t)    =<< cvtE x
        PrimBOr t                 -> A.uncurry (A.bor t)     =<< cvtE x
        PrimBXor t                -> A.uncurry (A.xor t)     =<< cvtE x
        PrimBNot t                -> A.complement t          =<< cvtE x
        PrimBShiftL t             -> A.uncurry (A.shiftL t)  =<< cvtE x
        PrimBShiftR t             -> A.uncurry (A.shiftR t)  =<< cvtE x
        PrimBRotateL t            -> A.uncurry (A.rotateL t) =<< cvtE x
        PrimBRotateR t            -> A.uncurry (A.rotateR t) =<< cvtE x
        PrimPopCount t            -> A.popCount t            =<< cvtE x
        PrimCountLeadingZeros t   -> A.countLeadingZeros t   =<< cvtE x
        PrimCountTrailingZeros t  -> A.countTrailingZeros t  =<< cvtE x
        PrimFDiv t                -> A.uncurry (A.fdiv t)    =<< cvtE x
        PrimRecip t               -> A.recip t               =<< cvtE x
        PrimSin t                 -> A.sin t                 =<< cvtE x
        PrimCos t                 -> A.cos t                 =<< cvtE x
        PrimTan t                 -> A.tan t                 =<< cvtE x
        PrimSinh t                -> A.sinh t                =<< cvtE x
        PrimCosh t                -> A.cosh t                =<< cvtE x
        PrimTanh t                -> A.tanh t                =<< cvtE x
        PrimAsin t                -> A.asin t                =<< cvtE x
        PrimAcos t                -> A.acos t                =<< cvtE x
        PrimAtan t                -> A.atan t                =<< cvtE x
        PrimAsinh t               -> A.asinh t               =<< cvtE x
        PrimAcosh t               -> A.acosh t               =<< cvtE x
        PrimAtanh t               -> A.atanh t               =<< cvtE x
        PrimAtan2 t               -> A.uncurry (A.atan2 t)   =<< cvtE x
        PrimExpFloating t         -> A.exp t                 =<< cvtE x
        PrimFPow t                -> A.uncurry (A.fpow t)    =<< cvtE x
        PrimSqrt t                -> A.sqrt t                =<< cvtE x
        PrimLog t                 -> A.log t                 =<< cvtE x
        PrimLogBase t             -> A.uncurry (A.logBase t) =<< cvtE x
        PrimTruncate ta tb        -> A.truncate ta tb        =<< cvtE x
        PrimRound ta tb           -> A.round ta tb           =<< cvtE x
        PrimFloor ta tb           -> A.floor ta tb           =<< cvtE x
        PrimCeiling ta tb         -> A.ceiling ta tb         =<< cvtE x
        PrimIsNaN t               -> A.isNaN t               =<< cvtE x
        PrimIsInfinite t          -> A.isInfinite t          =<< cvtE x
        PrimLt t                  -> A.uncurry (A.lt t)      =<< cvtE x
        PrimGt t                  -> A.uncurry (A.gt t)      =<< cvtE x
        PrimLtEq t                -> A.uncurry (A.lte t)     =<< cvtE x
        PrimGtEq t                -> A.uncurry (A.gte t)     =<< cvtE x
        PrimEq t                  -> A.uncurry (A.eq t)      =<< cvtE x
        PrimNEq t                 -> A.uncurry (A.neq t)     =<< cvtE x
        PrimMax t                 -> A.uncurry (A.max t)     =<< cvtE x
        PrimMin t                 -> A.uncurry (A.min t)     =<< cvtE x
        PrimLAnd                  -> A.land (inl x) (inr x)  -- short circuit
        PrimLOr                   -> A.lor  (inl x) (inr x)  -- short circuit
        PrimLNot                  -> A.lnot                  =<< cvtE x
        PrimOrd                   -> A.ord                   =<< cvtE x
        PrimChr                   -> A.chr                   =<< cvtE x
        PrimBoolToInt             -> A.boolToInt             =<< cvtE x
        PrimFromIntegral ta tb    -> A.fromIntegral ta tb    =<< cvtE x
        PrimToFloating ta tb      -> A.toFloating ta tb      =<< cvtE x
          -- no missing patterns, whoo!


-- | Extract the head of an index
--
indexHead :: IR (sh :. sz) -> IR sz
indexHead (IR (OP_Pair _ sz)) = IR sz

-- | Extract the tail of an index
--
indexTail :: IR (sh :. sz) -> IR sh
indexTail (IR (OP_Pair sh _)) = IR sh

-- | Construct an index from the head and tail
--
indexCons :: IR sh -> IR sz -> IR (sh :. sz)
indexCons (IR sh) (IR sz) = IR (OP_Pair sh sz)


-- | Convert a multidimensional array index into a linear index
--
intOfIndex :: forall sh. Shape sh => IR sh -> IR sh -> CodeGen (IR Int)
intOfIndex (IR extent) (IR index) = cvt (eltType (undefined::sh)) extent index
  where
    cvt :: TupleType t -> Operands t -> Operands t -> CodeGen (IR Int)
    cvt TypeRunit OP_Unit OP_Unit
      = return $ IR (constant (eltType (undefined :: Int)) 0)

    cvt (TypeRpair tsh t) (OP_Pair sh sz) (OP_Pair ix i)
      | Just Refl <- matchTupleType t (eltType (undefined::Int))
      -- If we short-circuit the last dimension, we can avoid inserting
      -- a multiply by zero and add of the result.
      = case matchTupleType tsh (eltType (undefined::Z)) of
          Just Refl -> return (IR i)
          Nothing   -> do
            a <- cvt tsh sh ix
            b <- A.mul numType a (IR sz)
            c <- A.add numType b (IR i)
            return c

    cvt (TypeRscalar t) _ (op' t -> i)
      | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
      = return $ ir t i

    cvt _ _ _
      = $internalError "intOfIndex" "expected shape with Int components"


-- | Convert a linear index into into a multidimensional index
--
indexOfInt :: forall sh. Shape sh => IR sh -> IR Int -> CodeGen (IR sh)
indexOfInt (IR extent) index = IR <$> cvt (eltType (undefined::sh)) extent index
  where
    cvt :: TupleType t -> Operands t -> IR Int -> CodeGen (Operands t)
    cvt TypeRunit OP_Unit _
      = return OP_Unit

    cvt (TypeRpair tsh tsz) (OP_Pair sh sz) i
      | Just Refl <- matchTupleType tsz (eltType (undefined::Int))
      = do
           i'    <- A.quot integralType i (IR sz)
           -- If we assume the index is in range, there is no point computing
           -- the remainder of the highest dimension since (i < sz) must hold
           IR r  <- case matchTupleType tsh (eltType (undefined::Z)) of
                      Just Refl -> return i     -- TODO: in debug mode assert (i < sz)
                      Nothing   -> A.rem  integralType i (IR sz)
           sh'   <- cvt tsh sh i'
           return $ OP_Pair sh' r

    cvt (TypeRscalar t) _ (IR i)
      | Just Refl <- matchScalarType t (scalarType :: ScalarType Int)
      = return i

    cvt _ _ _
      = $internalError "indexOfInt" "expected shape with Int components"