{-# LANGUAGE GeneralizedNewtypeDeriving, FlexibleContexts, LambdaCase, TypeSynonymInstances, FlexibleInstances, MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.CodeGen.ImpGen
  ( -- * Entry Points
    compileProg

    -- * Pluggable Compiler
  , OpCompiler
  , ExpCompiler
  , CopyCompiler
  , BodyCompiler
  , Operations (..)
  , defaultOperations
  , Destination (..)
  , ValueDestination (..)
  , MemLocation (..)
  , MemEntry (..)
  , ScalarEntry (..)

    -- * Monadic Compiler Interface
  , ImpM
  , Env (envVtable, envDefaultSpace)
  , subImpM
  , subImpM_
  , emit
  , collect
  , comment
  , VarEntry (..)
  , ArrayEntry (..)

    -- * Lookups
  , lookupVar
  , lookupArray
  , arrayLocation
  , lookupMemory

    -- * Building Blocks
  , compileSubExp
  , compileSubExpOfType
  , compileSubExpTo
  , compilePrimExp
  , compileAlloc
  , subExpToDimSize
  , declaringLParams
  , declaringFParams
  , declaringVarEntry
  , declaringScope
  , declaringScopes
  , declaringPrimVar
  , declaringPrimVars
  , withPrimVar
  , everythingVolatile
  , compileBody
  , compileLoopBody
  , defCompileBody
  , compileStms
  , compileExp
  , defCompileExp
  , sliceArray
  , offsetArray
  , strideArray
  , fullyIndexArray
  , fullyIndexArray'
  , varIndex
  , Imp.dimSizeToExp
  , dimSizeToSubExp
  , destinationFromParam
  , destinationFromParams
  , destinationFromPattern
  , funcallTargets
  , copy
  , copyDWIM
  , copyDWIMDest
  , copyElementWise
  )
  where

import Control.Monad.RWS    hiding (mapM, forM)
import Control.Monad.State  hiding (mapM, forM)
import Control.Monad.Writer hiding (mapM, forM)
import Control.Monad.Except hiding (mapM, forM)
import qualified Control.Monad.Fail as Fail
import Data.Either
import Data.Traversable
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import Data.List

import qualified Futhark.CodeGen.ImpCode as Imp
import Futhark.CodeGen.ImpCode
  (Count (..),
   Bytes, Elements,
   bytes, withElemType)
import Futhark.Representation.ExplicitMemory
import Futhark.Representation.SOACS (SOACS)
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Construct (fullSliceNum)
import Futhark.MonadFreshNames
import Futhark.Error
import Futhark.Util

-- | How to compile an 'Op'.
type OpCompiler lore op = Destination -> Op lore -> ImpM lore op ()

-- | How to compile a 'Body'.
type BodyCompiler lore op = Destination -> Body lore -> ImpM lore op ()

-- | How to compile an 'Exp'.
type ExpCompiler lore op = Destination -> Exp lore -> ImpM lore op ()

-- | A compiler for copies of elements of the given 'PrimType' between two
-- memory locations.
--
-- NOTE(review): the two 'MemLocation's are presumably the destination
-- followed by the source, matching how 'copy' is invoked; confirm against
-- 'defaultCopy'.
type CopyCompiler lore op =
     PrimType
  -> MemLocation
  -> MemLocation
  -> Count Elements -- ^ Number of row elements of the source.
  -> ImpM lore op ()

-- | The pluggable compilers used during code generation.  They are
-- installed in the environment by 'newEnv' and replaced for the duration
-- of a nested computation by 'subImpM'.
data Operations lore op =
  Operations
  { opsExpCompiler  :: ExpCompiler lore op  -- ^ Compiles expressions.
  , opsOpCompiler   :: OpCompiler lore op   -- ^ Compiles 'Op's.
  , opsBodyCompiler :: BodyCompiler lore op -- ^ Compiles bodies.
  , opsCopyCompiler :: CopyCompiler lore op -- ^ Emits array copies.
  }

-- | An operations set built from the default compilers: expressions are
-- compiled with 'defCompileExp', bodies with 'defCompileBody', and copies
-- with 'defaultCopy'.  Only the compiler for 'Op's must be supplied.
defaultOperations :: (ExplicitMemorish lore, FreeIn op) =>
                     OpCompiler lore op -> Operations lore op
defaultOperations opc =
  Operations { opsExpCompiler  = defCompileExp
             , opsOpCompiler   = opc
             , opsBodyCompiler = defCompileBody
             , opsCopyCompiler = defaultCopy
             }

-- | Where an array is stored: the memory block holding it, its shape, and
-- the index function describing its layout within that block.
data MemLocation =
  MemLocation
  { memLocationName  :: VName               -- ^ The memory block.
  , memLocationShape :: [Imp.DimSize]       -- ^ The array's dimensions.
  , memLocationIxFun :: IxFun.IxFun Imp.Exp -- ^ The index function.
  }
  deriving (Eq, Show)

-- | Symbol-table information about an array variable.
data ArrayEntry =
  ArrayEntry
  { entryArrayLocation :: MemLocation -- ^ Where the array is stored.
  , entryArrayElemType :: PrimType    -- ^ The type of its elements.
  }

-- | The shape of an array, as recorded in its 'MemLocation'.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape entry = memLocationShape (entryArrayLocation entry)

-- | Symbol-table information about a memory block.
data MemEntry =
  MemEntry
  { entryMemSize  :: Imp.MemSize -- ^ Size of the block.
  , entryMemSpace :: Imp.Space   -- ^ Memory space the block lives in.
  }

-- | Symbol-table information about a scalar variable: just its type.
newtype ScalarEntry =
  ScalarEntry { entryScalarType :: PrimType }

-- | The symbol-table entry of a variable in scope: an array, a scalar, or
-- a memory block.  Every variable has one, scalars included.
--
-- The optional expression is, presumably, the one that bound the
-- variable.  'compileStms' passes the binding expression when it declares
-- pattern variables, parameters are declared with 'Nothing', and
-- 'subImpM' resets it to 'Nothing'.
data VarEntry lore = ArrayVar (Maybe (Exp lore)) ArrayEntry
                   | ScalarVar (Maybe (Exp lore)) ScalarEntry
                   | MemVar (Maybe (Exp lore)) MemEntry

-- | Where the results of a compiled expression should end up, with one
-- 'ValueDestination' per result.
--
-- The optional tag refers to the construct that gave rise to this
-- destination.  For patterns it is the tag of the first name in the
-- pattern.  It can be used to relate the generated code to the original
-- code.
data Destination =
  Destination
  { destinationTag    :: Maybe Int
  , valueDestinations :: [ValueDestination]
  }
  deriving (Show)

-- | Where a single result value should be written.
data ValueDestination = ScalarDestination VName
                        -- ^ Assign to the named scalar variable.
                      | ArrayElemDestination VName PrimType Imp.Space (Count Bytes)
                        -- ^ Write one element of the given type into
                        -- the named memory block, which lives in the
                        -- given space, at the given byte offset.
                      | MemoryDestination VName
                        -- ^ The value is a memory block to be stored in
                        -- the named memory variable.
                      | ArrayDestination (Maybe MemLocation)
                        -- ^ The 'MemLocation' is 'Just' if a copy is
                        -- required.  If it is 'Nothing', then a
                        -- copy/assignment of a memory block somewhere
                        -- takes care of this array.
                      deriving (Show)

-- | If the given value destination is a 'ScalarDestination', return the
-- variable name.  Otherwise, return 'Nothing'.
fromScalarDestination :: ValueDestination -> Maybe VName
fromScalarDestination (ScalarDestination name) = Just name
fromScalarDestination _                        = Nothing

-- | The read-only environment of 'ImpM'.
data Env lore op =
  Env
  { envVtable       :: M.Map VName (VarEntry lore)
    -- ^ Entries for the variables currently in scope.
  , envExpCompiler  :: ExpCompiler lore op
  , envBodyCompiler :: BodyCompiler lore op
  , envOpCompiler   :: OpCompiler lore op
  , envCopyCompiler :: CopyCompiler lore op
  , envDefaultSpace :: Imp.Space
    -- ^ The memory space used when none is otherwise specified.
  , envVolatility   :: Imp.Volatility
    -- ^ Volatility attached to the memory writes emitted by 'writeExp'.
  }

-- | The initial environment: an empty symbol table, the compilers from the
-- given 'Operations', the given default memory space, and non-volatile
-- memory writes.
newEnv :: Operations lore op -> Imp.Space -> Env lore op
newEnv operations space =
  Env { envVolatility   = Imp.Nonvolatile
      , envDefaultSpace = space
      , envVtable       = M.empty
      , envOpCompiler   = opsOpCompiler operations
      , envExpCompiler  = opsExpCompiler operations
      , envCopyCompiler = opsCopyCompiler operations
      , envBodyCompiler = opsBodyCompiler operations
      }

-- | The code generation monad.  It reads an 'Env', accumulates emitted
-- 'Imp.Code', threads a 'VNameSource' for fresh names, and can fail with
-- an 'InternalError'.
newtype ImpM lore op a =
  ImpM (RWST (Env lore op) (Imp.Code op) VNameSource (Either InternalError) a)
  deriving ( Functor
           , Applicative
           , Monad
           , MonadState VNameSource
           , MonadReader (Env lore op)
           , MonadWriter (Imp.Code op)
           , MonadError InternalError
           )

-- | A failed pattern match in 'ImpM' aborts via 'error'.  It is not
-- reported as an 'InternalError'.
instance Fail.MonadFail (ImpM lore op) where
  fail msg = error $ "ImpM.fail: " ++ msg

-- | Fresh names come from the 'VNameSource' kept in the monad's state.
instance MonadFreshNames (ImpM lore op) where
  putNameSource src = put src
  getNameSource = get


-- | Presents the symbol table as a 'SOACS' scope by reconstructing the
-- type of every entry.
instance HasScope SOACS (ImpM lore op) where
  askScope = asks $ M.map (LetInfo . typeOfEntry) . envVtable
    where typeOfEntry = \case
            MemVar _ entry ->
              Mem (dimSizeToSubExp (entryMemSize entry)) (entryMemSpace entry)
            ArrayVar _ entry ->
              Array (entryArrayElemType entry)
                    (Shape (map dimSizeToSubExp (entryArrayShape entry)))
                    NoUniqueness
            ScalarVar _ entry ->
              Prim (entryScalarType entry)

-- | Run an 'ImpM' computation in a fresh environment built from the given
-- operations and default space.  Returns the result, the updated name
-- source, and the emitted code.
runImpM :: ImpM lore op a
        -> Operations lore op -> Imp.Space -> VNameSource
        -> Either InternalError (a, VNameSource, Imp.Code op)
runImpM (ImpM action) operations space src =
  runRWST action (newEnv operations space) src

-- | Like 'subImpM', but returns only the generated code.
subImpM_ :: Operations lore' op' -> ImpM lore' op' a
         -> ImpM lore op (Imp.Code op')
subImpM_ ops = fmap snd . subImpM ops

-- | Run a computation of a possibly different lore and op type under the
-- given operations.  Its code is returned rather than emitted.  The
-- symbol table is kept, with every recorded binding expression dropped.
-- The name source is shared with the enclosing computation, and errors
-- are rethrown.
subImpM :: Operations lore' op' -> ImpM lore' op' a
        -> ImpM lore op (a, Imp.Code op')
subImpM ops (ImpM m) = do
  outer_env <- ask
  let inner_env = outer_env { envOpCompiler   = opsOpCompiler ops
                            , envCopyCompiler = opsCopyCompiler ops
                            , envBodyCompiler = opsBodyCompiler ops
                            , envExpCompiler  = opsExpCompiler ops
                            , envVtable       = M.map forgetExp $ envVtable outer_env
                            }
  src <- getNameSource
  either throwError (\(x, src', code) -> (x, code) <$ putNameSource src') $
    runRWST m inner_env src
  where forgetExp (MemVar _ entry)    = MemVar Nothing entry
        forgetExp (ScalarVar _ entry) = ScalarVar Nothing entry
        forgetExp (ArrayVar _ entry)  = ArrayVar Nothing entry

-- | Run a code generation action and return the code it emitted, instead
-- of emitting that code in the enclosing computation.
collect :: ImpM lore op () -> ImpM lore op (Imp.Code op)
collect m = pass $ fmap (\((), emitted) -> (emitted, const mempty)) (listen m)

-- | Like 'collect', but also returns the action's result.
collect' :: ImpM lore op a -> ImpM lore op (a, Imp.Code op)
collect' m = pass $ fmap (\result -> (result, const mempty)) (listen m)

-- | Run a code generation action and emit its code wrapped in an
-- 'Imp.Comment' with the given description.
comment :: String -> ImpM lore op () -> ImpM lore op ()
comment desc m = emit . Imp.Comment desc =<< collect m

-- | Append generated imperative code to the output.
emit :: Imp.Code op -> ImpM lore op ()
emit code = tell code

-- | Compile every function of a program.  The name source is threaded
-- through the functions in order.  If any function fails, the error is
-- returned and the original name source is restored.
compileProg :: (ExplicitMemorish lore, MonadFreshNames m) =>
               Operations lore op -> Imp.Space
            -> Prog lore -> m (Either InternalError (Imp.Functions op))
compileProg ops ds prog = modifyNameSource compileAll
  where compileAll src =
          case mapAccumLM (compileFunDef ops ds) src (progFunctions prog) of
            Right (src', funs) -> (Right (Imp.Functions funs), src')
            Left err           -> (Left err, src)

-- | Compile a single function parameter.  Scalars and memory blocks become
-- Imp parameters ('Left').  Arrays become array declarations located in
-- their memory block ('Right').
compileInParam :: ExplicitMemorish lore =>
                  FParam lore -> ImpM lore op (Either Imp.Param ArrayDecl)
compileInParam fparam =
  case paramAttr fparam of
    MemPrim bt ->
      pure $ Left $ Imp.ScalarParam pname bt
    MemMem _ space ->
      pure $ Left $ Imp.MemParam pname space
    MemArray bt shape _ (ArrayIn mem ixfun) -> do
      dims <- mapM subExpToDimSize (shapeDims shape)
      let location = MemLocation mem dims (compilePrimExp <$> ixfun)
      pure $ Right $ ArrayDecl pname bt location
  where pname = paramName fparam

-- | An array parameter: its name, its element type, and where it is stored.
data ArrayDecl =
  ArrayDecl VName PrimType MemLocation

-- | The variables a parameter's type uses as sizes.  For a memory block
-- this is its size variable, if any.  Otherwise it is the variables among
-- the array dimensions.
fparamSizes :: Typed attr => Param attr -> S.Set VName
fparamSizes fparam =
  case paramType fparam of
    Mem (Var size) _ -> S.singleton size
    t                -> S.fromList $ subExpVars $ arrayDims t

-- | Compile the parameters of a function.  Returns three things:
--
--   * the Imp parameters (scalars and memory blocks);
--
--   * the declarations of array parameters, which are stored in memory
--     parameters rather than being Imp parameters themselves;
--
--   * the external (entry point) description of the value parameters,
--     driven by the given 'EntryPointType's.
compileInParams :: ExplicitMemorish lore =>
                   [FParam lore] -> [EntryPointType]
                -> ImpM lore op ([Imp.Param], [ArrayDecl], [Imp.ExternalValue])
compileInParams params orig_epts = do
  -- The trailing parameters correspond to the entry point types.  Any
  -- leading parameters form the context.
  let (ctx_params, val_params) =
        splitAt (length params - sum (map entryPointSize orig_epts)) params
  (inparams, arraydecls) <- partitionEithers <$> mapM compileInParam (ctx_params++val_params)
  let findArray x = find (isArrayDecl x) arraydecls
      -- Every variable used as an array dimension or memory size by some
      -- parameter.
      sizes = mconcat $ map fparamSizes $ ctx_params++val_params

      -- Size and space of every memory parameter, keyed by its name.  Only
      -- 64-bit integer constants and variables are recognised as sizes.
      summaries = M.fromList $ mapMaybe memSummary params
        where memSummary param
                | MemMem (Constant (IntValue (Int64Value size))) space <- paramAttr param =
                    Just (paramName param, (Imp.ConstSize size, space))
                | MemMem (Var size) space <- paramAttr param =
                    Just (paramName param, (Imp.VarSize size, space))
                | otherwise =
                    Nothing

      findMemInfo :: VName -> Maybe (Imp.MemSize, Space)
      findMemInfo = flip M.lookup summaries

      -- External description of one value parameter.  An array needs the
      -- size and space of its memory block, and yields 'Nothing' if that
      -- block is not a recognised memory parameter.  A scalar that serves
      -- as a size of some parameter is omitted.
      mkValueDesc fparam signedness =
        case (findArray $ paramName fparam, paramType fparam) of
          (Just (ArrayDecl _ bt (MemLocation mem shape _)), _) -> do
            (memsize, memspace) <- findMemInfo mem
            Just $ Imp.ArrayValue mem memsize memspace bt signedness shape
          (_, Prim bt)
            | paramName fparam `S.member` sizes ->
              Nothing
            | otherwise ->
              Just $ Imp.ScalarValue bt signedness $ paramName fparam
          _ ->
            Nothing

      -- Walk the entry point types alongside the value parameters.  An
      -- opaque type consumes @n@ parameters; unsigned and direct types
      -- consume one each.
      mkExts (TypeOpaque desc n:epts) fparams =
        let (fparams',rest) = splitAt n fparams
        in Imp.OpaqueValue desc
           (mapMaybe (`mkValueDesc` Imp.TypeDirect) fparams') :
           mkExts epts rest
      mkExts (TypeUnsigned:epts) (fparam:fparams) =
        maybeToList (Imp.TransparentValue <$> mkValueDesc fparam Imp.TypeUnsigned) ++
        mkExts epts fparams
      mkExts (TypeDirect:epts) (fparam:fparams) =
        maybeToList (Imp.TransparentValue <$> mkValueDesc fparam Imp.TypeDirect) ++
        mkExts epts fparams
      mkExts _ _ = []

  return (inparams, arraydecls, mkExts orig_epts val_params)
  where isArrayDecl x (ArrayDecl y _ _) = x == y

-- | Compile the return types of a function into out-parameters.  Returns
-- three things:
--
--   * the external (entry point) description of the results;
--
--   * the Imp out-parameters;
--
--   * the 'Destination' the function body should write its results to.
--
-- Out-parameters for the existential context (array sizes, memory sizes
-- and new memory blocks) are created on demand.  Size out-parameters are
-- shared per existential index.  The returned destination lists the
-- context destinations first, ordered by existential index, followed by
-- one destination per result.
compileOutParams :: ExplicitMemorish lore =>
                    [RetType lore] -> [EntryPointType]
                 -> ImpM lore op ([Imp.ExternalValue], [Imp.Param], Destination)
compileOutParams orig_rts orig_epts = do
  -- The state holds the (memory size, array size) out-parameters created
  -- so far, both keyed by existential index.  The writer collects the
  -- out-parameters and the context destinations, also keyed by
  -- existential index.
  ((extvs, dests), (outparams,ctx_dests)) <-
    runWriterT $ evalStateT (mkExts orig_epts orig_rts) (M.empty, M.empty)
  let ctx_dests' = map snd $ sortOn fst $ M.toList ctx_dests
  return (extvs, outparams, Destination Nothing $ ctx_dests' <> dests)
  where -- Lift an 'ImpM' action through the StateT and WriterT layers.
        imp = lift . lift

        -- Walk the entry point types alongside the return types.  An
        -- opaque type consumes @n@ return types; unsigned and direct types
        -- consume one each.
        mkExts (TypeOpaque desc n:epts) rts = do
          let (rts',rest) = splitAt n rts
          (evs, dests) <- unzip <$> zipWithM mkParam rts' (repeat Imp.TypeDirect)
          (more_values, more_dests) <- mkExts epts rest
          return (Imp.OpaqueValue desc evs : more_values,
                  dests ++ more_dests)
        mkExts (TypeUnsigned:epts) (rt:rts) = do
          (ev,dest) <- mkParam rt Imp.TypeUnsigned
          (more_values, more_dests) <- mkExts epts rts
          return (Imp.TransparentValue ev : more_values,
                  dest : more_dests)
        mkExts (TypeDirect:epts) (rt:rts) = do
          (ev,dest) <- mkParam rt Imp.TypeDirect
          (more_values, more_dests) <- mkExts epts rts
          return (Imp.TransparentValue ev : more_values,
                  dest : more_dests)
        mkExts _ _ = return ([], [])

        -- Create the out-parameter(s) for one result.  Returns its
        -- external description and its destination.
        mkParam MemMem{} _ =
          compilerBugS "Functions may not explicitly return memory blocks."
        mkParam (MemPrim t) ept = do
          out <- imp $ newVName "scalar_out"
          tell ([Imp.ScalarParam out t], mempty)
          return (Imp.ScalarValue t ept out, ScalarDestination out)
        -- NOTE(review): both cases below report the default space in the
        -- external value, even for 'ReturnsInBlock', whose block was looked
        -- up with 'lookupMemory'.  Confirm this is intended.
        mkParam (MemArray t shape _ attr) ept = do
          space <- asks envDefaultSpace
          (memout, memsize) <- case attr of
            -- A new block: a fresh memory out-parameter, plus a size
            -- out-parameter for the existential memory size.
            ReturnsNewBlock _ x x_size _ixfun -> do
              memout <- imp $ newVName "out_mem"
              sizeout <- ensureMemSizeOut x_size
              tell ([Imp.MemParam memout space],
                    M.singleton x $ MemoryDestination memout)
              return (memout, sizeout)
            -- An existing block: reuse its recorded size.
            ReturnsInBlock memout _ -> do
              memsize <- imp $ entryMemSize <$> lookupMemory memout
              return (memout, memsize)
          resultshape <- mapM inspectExtSize $ shapeDims shape
          -- No copy is needed: the array is delivered through its memory
          -- block.
          return (Imp.ArrayValue memout memsize space t ept resultshape,
                  ArrayDestination Nothing)

        -- Resolve an array dimension.  Each existential dimension gets an
        -- int32 out-parameter, created the first time its index is seen.
        inspectExtSize (Ext x) = do
          (memseen,arrseen) <- get
          case M.lookup x arrseen of
            Nothing -> do
              out <- imp $ newVName "out_arrsize"
              tell ([Imp.ScalarParam out int32],
                    M.singleton x $ ScalarDestination out)
              put (memseen, M.insert x out arrseen)
              return $ Imp.VarSize out
            Just out ->
              return $ Imp.VarSize out
        inspectExtSize (Free se) =
          imp $ subExpToDimSize se

        -- | Return the name of the out-parameter for the memory size
        -- 'x', creating it if it does not already exist.
        ensureMemSizeOut (Ext x) = do
          (memseen, arrseen) <- get
          case M.lookup x memseen of
            Nothing -> do sizeout <- imp $ newVName "out_memsize"
                          tell ([Imp.ScalarParam sizeout int64],
                                M.singleton x $ ScalarDestination sizeout)
                          put (M.insert x sizeout memseen, arrseen)
                          return $ Imp.VarSize sizeout
            Just sizeout -> return $ Imp.VarSize sizeout
        ensureMemSizeOut (Free v) = imp $ subExpToDimSize v

-- | Compile a single function definition into an Imp function.  Functions
-- that are not entry points get 'TypeDirect' for every parameter and
-- result.
compileFunDef :: ExplicitMemorish lore =>
                 Operations lore op -> Imp.Space
              -> VNameSource
              -> FunDef lore
              -> Either InternalError (VNameSource, (Name, Imp.Function op))
compileFunDef ops ds src (FunDef entry fname rettype params body) =
  package <$> runImpM compile ops ds src
  where (params_entry, ret_entry) =
          fromMaybe (replicate (length params) TypeDirect,
                     replicate (length rettype) TypeDirect)
                    entry

        compile = do
          (inparams, arraydecls, args) <- compileInParams params params_entry
          (results, outparams, dests) <- compileOutParams rettype ret_entry
          withFParams params $ withArrays arraydecls $ compileBody dests body
          return (outparams, inparams, results, args)

        package ((outparams, inparams, results, args), src', body') =
          (src', (fname, Imp.Function (isJust entry) outparams inparams body' results args))

-- | Compile a body with the body compiler in the current environment.
compileBody :: Destination -> Body lore -> ImpM lore op ()
compileBody dest body =
  asks envBodyCompiler >>= \compile -> compile dest body

-- | The default body compiler: compile the statements, then write each
-- result to the corresponding destination.
defCompileBody :: (ExplicitMemorish lore, FreeIn op) => Destination -> Body lore -> ImpM lore op ()
defCompileBody (Destination _ targets) (Body _ stms results) =
  compileStms (freeIn results) (stmsToList stms) $
    sequence_ $ zipWith compileSubExpTo targets results

-- | Compile a loop body and return its code rather than emitting it.  The
-- code ends by assigning the body's results to the merge parameters named
-- in the first argument, pairing them positionally with the results.
-- Only scalar results, and memory results given as variables, are
-- assigned here.  Any other result (e.g. an array) produces no
-- assignment.
compileLoopBody :: (ExplicitMemorish lore, FreeIn op) =>
                   [VName] -> Body lore -> ImpM lore op (Imp.Code op)
compileLoopBody mergenames (Body _ bnds ses) = do
  -- We cannot write the results to the merge parameters immediately,
  -- as some of the results may actually *be* merge parameters, and
  -- would thus be clobbered.  Therefore, we first copy to new
  -- variables mirroring the merge parameters, and then copy this
  -- buffer to the merge parameters.  This is efficient, because the
  -- operations are all scalar operations.
  tmpnames <- mapM (newVName . (++"_tmp") . baseString) mergenames
  collect $ compileStms (freeIn ses) (stmsToList bnds) $ do
    -- Each iteration emits the copy into the temporary now and returns the
    -- (not yet run) action that copies the temporary into the merge
    -- parameter.
    copy_to_merge_params <- forM (zip3 mergenames tmpnames ses) $ \(d,tmp,se) ->
      subExpType se >>= \case
        Prim bt  -> do
          se' <- compileSubExp se
          emit $ Imp.DeclareScalar tmp bt
          emit $ Imp.SetScalar tmp se'
          return $ emit $ Imp.SetScalar d $ Imp.var tmp bt
        Mem _ space | Var v <- se -> do
          emit $ Imp.DeclareMem tmp space
          emit $ Imp.SetMem tmp v space
          return $ emit $ Imp.SetMem d tmp space
        _ -> return $ return ()
    -- Only after all temporaries are set are the merge parameters
    -- overwritten.
    sequence_ copy_to_merge_params

-- | Compile a sequence of statements, followed by the continuation @m@.
-- The continuation's code is emitted after all the statements, and the
-- statements' pattern variables are in scope for it.
--
-- The first argument is the set of names that must be considered live
-- after the continuation.  Memory blocks bound by a statement's pattern
-- are freed right after the code of the last later statement that uses
-- them, unless they are still free in subsequent code, in the
-- continuation, or in that set.  A block never used by a later statement
-- is not freed here.
compileStms :: (ExplicitMemorish lore, FreeIn op) =>
               Names -> [Stm lore] -> ImpM lore op () -> ImpM lore op ()
compileStms alive_after_stms all_stms m =
  -- We keep track of any memory blocks produced by the statements,
  -- and after the last time that memory block is used, we insert a
  -- Free.  This is very conservative, but can cut down on lifetimes
  -- in some cases.
  void $ compileStms' mempty all_stms
  -- 'compileStms'' takes the (name, space) pairs of memory blocks bound
  -- by earlier statements and returns the names free in the code from
  -- this point onwards.
  where compileStms' allocs (Let pat _ e:bs) =
          declaringVars (Just e) (patternElements pat) $ do
          dest <- destinationFromPattern pat

          -- Generate this statement's code and the code of the rest before
          -- emitting either, so that liveness is known.
          e_code <- collect $ compileExp dest e
          (live_after, bs_code) <- collect' $ compileStms' (patternAllocs pat <> allocs) bs
          -- A block dies here if this statement uses it and no later code
          -- does.
          let dies_here v = not (v `S.member` live_after) &&
                            v `S.member` freeIn e_code
              to_free = S.filter (dies_here . fst) allocs

          emit e_code
          mapM_ (emit . uncurry Imp.Free) to_free
          emit bs_code

          return $ freeIn e_code <> live_after
        compileStms' _ [] = do
          code <- collect m
          emit code
          return $ freeIn code <> alive_after_stms

        -- The memory blocks (with their spaces) bound by a pattern.
        patternAllocs = S.fromList . mapMaybe isMemPatElem . patternElements
        isMemPatElem pe = case patElemType pe of
                            Mem _ space -> Just (patElemName pe, space)
                            _           -> Nothing

-- | Compile an expression with the expression compiler in the current
-- environment.
compileExp :: Destination -> Exp lore -> ImpM lore op ()
compileExp targets e =
  asks envExpCompiler >>= \compile -> compile targets e

-- | The default expression compiler.
defCompileExp :: (ExplicitMemorish lore, FreeIn op) =>
                 Destination -> Exp lore -> ImpM lore op ()

-- Both branches write to the same destination.  The condition is
-- compiled first, then the true branch, then the false branch.
defCompileExp dest (If cond tbranch fbranch _) =
  emit =<< Imp.If <$> compileSubExp cond
                  <*> collect (compileBody dest tbranch)
                  <*> collect (compileBody dest fbranch)

-- A function call.  Scalar arguments are passed as expressions and memory
-- variables as memory arguments.  Any other argument (e.g. an array) is
-- not passed.
defCompileExp dest (Apply fname args _ _) = do
  targets <- funcallTargets dest
  compiled_args <- fmap catMaybes $ forM args $ \(se, _) ->
    argFor se <$> subExpType se
  emit $ Imp.Call targets fname compiled_args
  where argFor se (Prim pt)    = Just $ Imp.ExpArg $ compileSubExpOfType pt se
        argFor (Var v) Mem{}   = Just $ Imp.MemArg v
        argFor _ _             = Nothing

-- Basic operations are delegated to 'defCompileBasicOp'.
defCompileExp dest (BasicOp basic) = defCompileBasicOp dest basic

-- A loop.  The merge parameters are declared and initialised, the body is
-- compiled with 'compileLoopBody' inside the loop construct, and the final
-- values of the merge parameters are written to the destination.
defCompileExp (Destination _ dest) (DoLoop ctx val form body) =
  declaringFParams mergepat $ do
    -- Initialise the merge parameters from their initial values.  Array
    -- initial values are skipped here.  NOTE(review): presumably arrays
    -- already share memory with their merge parameter; confirm.
    forM_ merge $ \(p, se) -> do
      na <- subExpNotArray se
      when na $
        copyDWIM (paramName p) [] se []
    -- 'bindForm' brings the loop's own variables into scope, and
    -- 'emitForm' wraps the compiled body in the loop construct.
    (bindForm, emitForm) <-
      case form of
        ForLoop i it bound loopvars -> do
          bound' <- compileSubExp bound
          -- At the start of each iteration, every scalar loop parameter is
          -- set to element @i@ of its array.  Other loop parameters are not
          -- assigned here.
          let setLoopParam (p,a)
                | Prim _ <- paramType p =
                    copyDWIM (paramName p) [] (Var a) [varIndex i]
                | otherwise =
                    return ()

          let emitForm body' = do
                set_loop_params <- collect $ mapM_ setLoopParam loopvars
                emit $ Imp.For i it bound' $ set_loop_params<>body'
          return (declaringLParams (map fst loopvars) .
                  declaringLoopVar i it,
                  emitForm)
        -- A while loop is controlled by the boolean variable @cond@.
        WhileLoop cond ->
          return (id, emit . Imp.While (Imp.var cond Bool))

    bindForm $ do
      body' <- compileLoopBody mergenames body
      emitForm body'
    -- The loop's results are the final values of the merge parameters.
    zipWithM_ compileSubExpTo dest $ map (Var . paramName . fst) merge
    where merge = ctx ++ val
          mergepat = map fst merge
          mergenames = map paramName mergepat

-- Lore-specific operations go to the 'Op' compiler in the environment.
defCompileExp dest (Op op) =
  asks envOpCompiler >>= \compile -> compile dest op

-- | The default compiler for basic operations.
defCompileBasicOp :: Destination -> BasicOp lore -> ImpM lore op ()

-- Plain and opaque subexpressions are written directly to the target.
defCompileBasicOp (Destination _ [target]) (SubExp se) = compileSubExpTo target se

defCompileBasicOp (Destination _ [target]) (Opaque se) = compileSubExpTo target se

-- Scalar operators: compile the operands, then write the resulting
-- expression to the single target.
defCompileBasicOp (Destination _ [target]) (UnOp op e) =
  writeExp target . Imp.UnOpExp op =<< compileSubExp e

defCompileBasicOp (Destination _ [target]) (ConvOp conv e) =
  writeExp target . Imp.ConvOpExp conv =<< compileSubExp e

defCompileBasicOp (Destination _ [target]) (BinOp bop x y) =
  writeExp target =<< Imp.BinOpExp bop <$> compileSubExp x <*> compileSubExp y

defCompileBasicOp (Destination _ [target]) (CmpOp bop x y) =
  writeExp target =<< Imp.CmpOpExp bop <$> compileSubExp x <*> compileSubExp y

-- Assertions emit an 'Imp.Assert' with the compiled condition and the
-- compiled message parts.  Nothing is written to the destination.
defCompileBasicOp (Destination _ [_]) (Assert e msg loc) =
  emit =<< Imp.Assert <$> compileSubExp e
                      <*> traverse compileSubExp msg
                      <*> pure loc

-- An index that fixes every dimension copies the selected element(s) to
-- the target.
defCompileBasicOp (Destination _ [target]) (Index src slice)
  | Just is <- sliceIndices slice =
      let is' = compileSubExpOfType int32 <$> is
      in copyDWIMDest target [] (Var src) is'

-- Any other index emits no code here.  NOTE(review): presumably the result
-- is expressed through its memory location instead; confirm.
defCompileBasicOp _ Index{} = pure ()

-- In-place update: write the value into the destination array at the
-- given slice.  A slice that fixes every dimension becomes a single
-- element write; any other slice becomes a copy into the sliced view of
-- the array.
defCompileBasicOp (Destination _ [ArrayDestination (Just memloc)]) (Update _ slice se) = do
  bt <- elemType <$> subExpType se
  let toExp = compileSubExpOfType int32
  target' <- case sliceIndices slice of
    Just is -> do
      (_, space, elemOffset) <- fullyIndexArray' memloc (map toExp is) bt
      return $ ArrayElemDestination (memLocationName memloc) bt space elemOffset
    Nothing ->
      return $ ArrayDestination $ Just $ sliceArray memloc $ map (fmap toExp) slice
  copyDWIMDest target' [] se []

-- Replication: a loop nest over every dimension of the shape.  The
-- innermost body copies the replicated value to the current position.
defCompileBasicOp (Destination _ [dest]) (Replicate (Shape ds) se) = do
  is <- replicateM (length ds) (newVName "i")
  ds' <- mapM compileSubExp ds
  declaringLoopVars Int32 is $ do
    copy_elem <- collect $ copyDWIMDest dest (map varIndex is) se []
    -- The first dimension's loop is outermost.
    let loops = zipWith (\i d -> Imp.For i Int32 d) is ds'
    emit $ foldr ($) copy_elem loops

-- Scratch arrays emit no code here.
defCompileBasicOp (Destination _ [_]) Scratch{} = pure ()

-- Iota: for each i in [0, n), compute e + i*s into a scalar of the
-- element type and copy it to position i of the destination.
defCompileBasicOp (Destination _ [dest]) (Iota n e s et) = do
  i <- newVName "i"
  x <- newVName "x"
  count <- compileSubExp n
  start <- compileSubExp e
  stride <- compileSubExp s
  emit $ Imp.DeclareScalar x $ IntType et
  -- The loop counter is 32-bit, so sign-extend it to the element type.
  let i_ext = ConvOpExp (SExt Int32 et) $ Imp.var i $ IntType Int32
  declaringLoopVar i Int32 $ withPrimVar x (IntType et) $ do
    loop_body <- collect $ do
      emit $ Imp.SetScalar x $ start + i_ext * stride
      copyDWIMDest dest [varIndex i] (Var x) []
    emit $ Imp.For i Int32 count loop_body

-- Explicit copies and manifestations both copy the source array to the
-- target.
defCompileBasicOp (Destination _ [target]) (Copy arr) =
  compileSubExpTo target (Var arr)

defCompileBasicOp (Destination _ [target]) (Manifest _ arr) =
  compileSubExpTo target (Var arr)

-- Concatenation along dimension @i@.  Each input array is copied into a
-- view of the destination whose dimension @i@ is offset by a running
-- counter.  After each copy the counter advances by the size of that
-- input's dimension @i@.
defCompileBasicOp
  (Destination _ [ArrayDestination (Just (MemLocation destmem destshape destixfun))])
  (Concat i x ys _) = do
    xtype <- lookupType x
    offs_glb <- newVName "tmp_offs"
    withPrimVar offs_glb int32 $ do
      emit $ Imp.DeclareScalar offs_glb int32
      emit $ Imp.SetScalar offs_glb 0
      -- Move dimension i to the front, offset it by the running counter,
      -- then move it back into place.
      let perm = [i] ++ [0..i-1] ++ [i+1..length destshape-1]
          invperm = rearrangeInverse perm
          destloc = MemLocation destmem destshape
                    (IxFun.permute (IxFun.offsetIndex (IxFun.permute destixfun perm) $
                                    varIndex offs_glb)
                     invperm)

      forM_ (x:ys) $ \y -> do
        yentry <- lookupArray y
        let srcloc = entryArrayLocation yentry
            -- Every input must have a dimension i.  A missing one is an
            -- internal invariant violation, reported like the other
            -- compiler bugs in this module.
            rows = case drop i $ entryArrayShape yentry of
                     []  -> compilerBugS $
                            "defCompileBasicOp Concat: empty array shape for " ++ pretty y
                     r:_ -> innerExp $ Imp.dimSizeToExp r
        copy (elemType xtype) destloc srcloc $ arrayOuterSize yentry
        emit $ Imp.SetScalar offs_glb $ Imp.var offs_glb int32 + rows

-- Array literals.  When the destination has a known memory location and
-- every element is a constant, the constants are emitted as a static array
-- in the destination's memory space and copied in one go.  Otherwise each
-- element is copied to its constant index individually.
defCompileBasicOp (Destination _ [dest]) (ArrayLit es _)
  | ArrayDestination (Just dest_mem) <- dest,
    Just vs@(v:_) <- mapM isLiteral es = do
      dest_space <- entryMemSpace <$> lookupMemory (memLocationName dest_mem)
      -- The element type is taken from the first literal.
      let t = primValueType v
      static_array <- newVName "static_array"
      emit $ Imp.DeclareArray static_array dest_space t vs
      let static_src = MemLocation static_array [Imp.ConstSize $ fromIntegral $ length es] $
                       IxFun.iota [fromIntegral $ length es]
          num_bytes = Imp.ConstSize $ fromIntegral (length es) * primByteSize t
          entry = MemVar Nothing $ MemEntry num_bytes dest_space
      -- The static array must be in the symbol table for the copy to
      -- look it up.
      local (insertInVtable static_array entry) $
        copy t dest_mem static_src $ fromIntegral $ length es
  | otherwise =
    forM_ (zip [0..] es) $ \(i,e) ->
      copyDWIMDest dest [constIndex i] e []

  where isLiteral (Constant v) = Just v
        isLiteral _ = Nothing

-- These operations emit no code here.  NOTE(review): presumably their
-- results are expressed through index functions on existing memory;
-- confirm.
defCompileBasicOp _ Rearrange{} = pure ()

defCompileBasicOp _ Rotate{} = pure ()

defCompileBasicOp _ Reshape{} = pure ()

defCompileBasicOp _ Repeat{} = pure ()

-- Partitioning @value_arrs@ into @n@ classes according to @flags@.  The
-- first @n@ destinations must be scalars and receive the size of each
-- class.  The remaining destinations must be arrays with known locations
-- and receive the partitioned values.
defCompileBasicOp (Destination _ dests) (Partition n flags value_arrs)
  | (sizedests, arrdest) <- splitAt n dests,
    Just sizenames <- mapM fromScalarDestination sizedests,
    Just destlocs <- mapM arrDestLoc arrdest = do
  i <- newVName "i"
  declaringLoopVar i Int32 $ do
    outer_dim <- compileSubExp =<< (arraySize 0 <$> lookupType flags)
    -- We will use 'i' to index the flag array and the value array.
    -- Note that they have the same outer size ('outer_dim').
    read_flags_i <- readFromArray flags [varIndex i]

    -- First, for each of the 'n' output arrays, we compute the final
    -- size.  This is done by iterating through the flag array, but
    -- first we declare scalars to hold the size.  We do this by
    -- creating a mapping from equivalence classes to the name of the
    -- scalar holding the size.
    let sizes = M.fromList $ zip [0..n-1] sizenames

    -- We initialise each size to zero.
    forM_ sizenames $ \sizename ->
      emit $ Imp.SetScalar sizename 0

    -- Now iterate across the flag array, storing each element in
    -- 'eqclass', then comparing it to the known classes and increasing
    -- the appropriate size variable.
    eqclass <- newVName "eqclass"
    emit $ Imp.DeclareScalar eqclass int32
    let mkSizeLoopBody code c sizevar =
          Imp.If (Imp.CmpOpExp (CmpEq int32) (Imp.var eqclass int32) (fromIntegral c))
          (Imp.SetScalar sizevar $ Imp.var sizevar int32 + 1)
          code
        sizeLoopBody = M.foldlWithKey' mkSizeLoopBody Imp.Skip sizes
    emit $ Imp.For i Int32 outer_dim $
      Imp.SetScalar eqclass read_flags_i <>
      sizeLoopBody

    -- We can now compute the starting offsets of each of the
    -- partitions, creating a map from equivalence class to its
    -- corresponding offset.
    offsets <- flip evalStateT 0 $ forM sizes $ \size -> do
      cur_offset <- get
      partition_offset <- lift $ newVName "partition_offset"
      lift $ emit $ Imp.DeclareScalar partition_offset int32
      lift $ emit $ Imp.SetScalar partition_offset cur_offset
      put $ Imp.var partition_offset int32 + Imp.var size int32
      return partition_offset

    -- We create the memory location we use when writing a result
    -- element.  This is basically the index function of 'destloc', but
    -- with a dynamic offset, stored in 'partition_cur_offset'.
    partition_cur_offset <- newVName "partition_cur_offset"
    emit $ Imp.DeclareScalar partition_cur_offset int32

    -- Finally, we iterate through the data array and flag array in
    -- parallel, and put each element where it is supposed to go.  Note
    -- that after writing to a partition, we increase the corresponding
    -- offset.
    ets <- mapM (fmap elemType . lookupType) value_arrs
    srclocs <- mapM arrayLocation value_arrs
    copy_elements <- forM (zip3 destlocs ets srclocs) $ \(destloc,et,srcloc) ->
      copyArrayDWIM et
      destloc [varIndex partition_cur_offset]
      srcloc [varIndex i]
    let mkWriteLoopBody code c offsetvar =
          Imp.If (Imp.CmpOpExp (CmpEq int32) (Imp.var eqclass int32) (fromIntegral c))
          (Imp.SetScalar partition_cur_offset
             (Imp.var offsetvar int32)
           <>
           mconcat copy_elements
           <>
           Imp.SetScalar offsetvar
             (Imp.var offsetvar int32 + 1))
          code
        writeLoopBody = M.foldlWithKey' mkWriteLoopBody Imp.Skip offsets
    emit $ Imp.For i Int32 outer_dim $
      Imp.SetScalar eqclass read_flags_i <>
      writeLoopBody
    return ()
  where arrDestLoc (ArrayDestination destloc) = destloc
        arrDestLoc _ = Nothing

-- With no destinations there is nothing to write, so no code is emitted.
defCompileBasicOp (Destination _ []) _ = pure ()

-- Any other combination of destination and operation is an internal
-- error.
defCompileBasicOp target e =
  compilerBugS $ concat
    [ "ImpGen.defCompileBasicOp: Invalid target\n  "
    , show target
    , "\nfor expression\n  "
    , pretty e
    ]

-- | Emit code that stores the value of an expression into a
-- single-value destination.  Only scalar destinations and single array
-- element destinations are supported; anything else is a compiler bug.
writeExp :: ValueDestination -> Imp.Exp -> ImpM lore op ()
writeExp target e = case target of
  ScalarDestination name ->
    emit $ Imp.SetScalar name e
  ArrayElemDestination mem bt space offset -> do
    vol <- asks envVolatility
    emit $ Imp.Write mem offset bt space vol e
  _ ->
    compilerBugS $ "Cannot write " ++ pretty e ++ " to " ++ show target

-- | Bind a name to an entry in the environment's symbol table,
-- replacing any existing entry for that name.
insertInVtable :: VName -> VarEntry lore -> Env lore op -> Env lore op
insertInVtable name entry env = env { envVtable = vtable' }
  where vtable' = M.insert name entry (envVtable env)

-- | Run an action with the declared array in scope.  No code is
-- emitted; only the symbol table is extended.
withArray :: ArrayDecl -> ImpM lore op a -> ImpM lore op a
withArray (ArrayDecl name bt location) =
  local $ insertInVtable name $
    ArrayVar Nothing ArrayEntry { entryArrayLocation = location
                                , entryArrayElemType = bt
                                }

-- | Run an action with every declared array in scope, as by 'withArray'.
withArrays :: [ArrayDecl] -> ImpM lore op a -> ImpM lore op a
withArrays decls m = foldr withArray m decls

-- | Like 'declaringFParams', but does not create new declarations: the
-- parameters are only added to the symbol table for the duration of
-- the action.
withFParams :: ExplicitMemorish lore => [FParam lore] -> ImpM lore op a -> ImpM lore op a
withFParams fparams m = foldr withFParam m fparams
  where withFParam fparam inner = do
          entry <- memBoundToVarEntry Nothing $ noUniquenessReturns $ paramAttr fparam
          local (insertInVtable (paramName fparam) entry) inner

-- | Declare the names bound by the given pattern elements for the
-- duration of an action (see 'declaringScope').  The optional
-- expression is recorded in each resulting symbol table entry.
declaringVars :: ExplicitMemorish lore =>
                 Maybe (Exp lore) -> [PatElem lore] -> ImpM lore op a -> ImpM lore op a
declaringVars e pes m = foldr (declaringScope e . scopeOfPatElem) m pes

-- | Declare function parameters for the duration of an action,
-- emitting declarations for scalar and memory parameters.
declaringFParams :: ExplicitMemorish lore => [FParam lore] -> ImpM lore op a -> ImpM lore op a
declaringFParams fparams = declaringScope Nothing (scopeOfFParams fparams)

-- | Declare lambda parameters for the duration of an action,
-- emitting declarations for scalar and memory parameters.
declaringLParams :: ExplicitMemorish lore => [LParam lore] -> ImpM lore op a -> ImpM lore op a
declaringLParams lparams = declaringScope Nothing (scopeOfLParams lparams)

-- | Bring a variable into scope for the duration of an action.  Memory
-- blocks and scalars first get an explicit declaration emitted.  Arrays
-- get no declaration (their entry refers to a memory block by name) and
-- are only added to the symbol table.
declaringVarEntry :: VName -> VarEntry lore -> ImpM lore op a -> ImpM lore op a
declaringVarEntry name entry m = do
  mapM_ emit $ declaration entry
  local (insertInVtable name entry) m
  where declaration (MemVar _ mem_entry) =
          Just $ Imp.DeclareMem name $ entryMemSpace mem_entry
        declaration (ScalarVar _ scalar_entry) =
          Just $ Imp.DeclareScalar name $ entryScalarType scalar_entry
        declaration ArrayVar{} =
          Nothing

-- | Declare a scalar of the given primitive type for the duration of an
-- action (emitting its declaration).
declaringPrimVar :: VName -> PrimType -> ImpM lore op a -> ImpM lore op a
declaringPrimVar name bt m =
  declaringVarEntry name (ScalarVar Nothing (ScalarEntry bt)) m

-- | Declare several scalars for the duration of an action, as by
-- 'declaringPrimVar'.
declaringPrimVars :: [(VName,PrimType)] -> ImpM lore op a -> ImpM lore op a
declaringPrimVars vars m = foldr (\(name, bt) -> declaringPrimVar name bt) m vars

-- | Translate a memory annotation into a symbol table entry.  Primitive
-- values become scalars, memory annotations become memory blocks (their
-- size converted with 'subExpToDimSize'), and arrays record their memory
-- block, shape and index function.
memBoundToVarEntry :: Maybe (Exp lore) -> MemBound NoUniqueness
                   -> ImpM lore op (VarEntry lore)
memBoundToVarEntry e bound =
  case bound of
    MemPrim bt ->
      return $ ScalarVar e ScalarEntry { entryScalarType = bt }
    MemMem size space -> do
      size' <- subExpToDimSize size
      return $ MemVar e MemEntry { entryMemSize = size'
                                 , entryMemSpace = space
                                 }
    MemArray bt shape _ (ArrayIn mem ixfun) -> do
      dims <- mapM subExpToDimSize $ shapeDims shape
      let location = MemLocation mem dims $ compilePrimExp <$> ixfun
      return $ ArrayVar e ArrayEntry { entryArrayLocation = location
                                     , entryArrayElemType = bt
                                     }

-- | Declare a single name from a scope, given its name info, for the
-- duration of an action.  Index variables are treated as scalars of
-- their integer type.
declaringName :: Maybe (Exp lore) -> VName -> NameInfo ExplicitMemory
              -> ImpM lore op a -> ImpM lore op a
declaringName e name info m =
  memBoundToVarEntry e (memBound info) >>= \entry ->
  declaringVarEntry name entry m
  where memBound = \case
          LetInfo attr -> attr
          FParamInfo attr -> noUniquenessReturns attr
          LParamInfo attr -> attr
          IndexInfo it -> MemPrim $ IntType it

-- | Declare every name in a scope (in ascending key order) for the
-- duration of an action.
declaringScope :: Maybe (Exp lore) -> Scope ExplicitMemory -> ImpM lore op a -> ImpM lore op a
declaringScope e scope m = M.foldrWithKey (declaringName e) m scope

-- | Declare several scopes, each with its own optional expression, for
-- the duration of an action.
declaringScopes :: [(Maybe (Exp lore), Scope ExplicitMemory)] -> ImpM lore op a -> ImpM lore op a
declaringScopes es_and_scopes m =
  foldr (\(e, scope) -> declaringScope e scope) m es_and_scopes

-- | Like 'declaringPrimVar', but emits no declaration: the scalar is
-- only added to the symbol table.
withPrimVar :: VName -> PrimType -> ImpM lore op a -> ImpM lore op a
withPrimVar name bt m = local (insertInVtable name entry) m
  where entry = ScalarVar Nothing ScalarEntry { entryScalarType = bt }

-- | Bring several loop counters of one integer type into scope, as by
-- 'declaringLoopVar'.
declaringLoopVars :: IntType -> [VName] -> ImpM lore op a -> ImpM lore op a
declaringLoopVars it names m = foldr (\name -> declaringLoopVar name it) m names

-- | Bring a loop counter of the given integer type into scope.  No
-- declaration is emitted (see 'withPrimVar').
declaringLoopVar :: VName -> IntType -> ImpM lore op a -> ImpM lore op a
declaringLoopVar name it = withPrimVar name (IntType it)

-- | Run an action with 'envVolatility' set to 'Imp.Volatile'.  Reads and
-- writes generated in this module pass that flag on to 'Imp.Write' and
-- 'Imp.index'.
everythingVolatile :: ImpM lore op a -> ImpM lore op a
everythingVolatile m = local setVolatile m
  where setVolatile env = env { envVolatility = Imp.Volatile }

-- | The names that receive the results of a function call, in order.
-- Scalar and memory destinations contribute their name, and array
-- destinations are dropped.  Array element destinations are not
-- supported and raise a compiler bug.
funcallTargets :: Destination -> ImpM lore op [VName]
funcallTargets (Destination _ dests) =
  fmap concat $ forM dests $ \case
    ScalarDestination name -> return [name]
    MemoryDestination name -> return [name]
    ArrayDestination _ -> return []
    ArrayElemDestination{} ->
      compilerBugS "Cannot put scalar function return in-place yet." -- FIXME

-- | Convert a size subexpression to an Imp dimension size.  Variables
-- become variable sizes; only int32 and int64 constants are accepted.
subExpToDimSize :: SubExp -> ImpM lore op Imp.DimSize
subExpToDimSize se = case se of
  Var v ->
    pure $ Imp.VarSize v
  Constant (IntValue (Int64Value i)) ->
    pure $ Imp.ConstSize $ fromIntegral i
  Constant (IntValue (Int32Value i)) ->
    pure $ Imp.ConstSize $ fromIntegral i
  Constant{} ->
    compilerBugS "Size subexp is not an int32 or int64 constant."

-- | Store the value of a subexpression in a destination, with no
-- indexing on either side.
compileSubExpTo :: ValueDestination -> SubExp -> ImpM lore op ()
compileSubExpTo target src = copyDWIMDest target [] src []

-- | Translate a subexpression into an Imp expression.  Variables must
-- have a primitive type; anything else is a compiler bug.
compileSubExp :: SubExp -> ImpM lore op Imp.Exp
compileSubExp se = case se of
  Constant v -> return $ Imp.ValueExp v
  Var v -> lookupType v >>= \case
    Prim pt -> return $ Imp.var v pt
    _       -> compilerBugS $ "compileSubExp: SubExp is not a primitive type: " ++ pretty v

-- | Like 'compileSubExp', but pure: the type of a variable is supplied
-- by the caller instead of being looked up.  Constants ignore it.
compileSubExpOfType :: PrimType -> SubExp -> Imp.Exp
compileSubExpOfType t se = case se of
  Constant v -> Imp.ValueExp v
  Var v      -> Imp.var v t

-- | Turn every variable leaf of a primitive expression into an Imp
-- scalar variable.
compilePrimExp :: PrimExp VName -> Imp.Exp
compilePrimExp e = Imp.ScalarVar <$> e

-- | An int32-typed index expression that reads the given scalar variable.
varIndex :: VName -> Imp.Exp
varIndex = flip LeafExp int32 . Imp.ScalarVar

-- | A constant index expression.
constIndex :: Int -> Imp.Exp
constIndex i = fromIntegral i

-- | Find the symbol table entry for a name; an unknown name is a
-- compiler bug.
lookupVar :: VName -> ImpM lore op (VarEntry lore)
lookupVar name =
  maybe (compilerBugS $ "Unknown variable: " ++ pretty name) return
    =<< asks (M.lookup name . envVtable)

-- | Find the array entry for a name; a name bound to anything other
-- than an array is a compiler bug.
lookupArray :: VName -> ImpM lore op ArrayEntry
lookupArray name = lookupVar name >>= \case
  ArrayVar _ entry -> return entry
  _                -> compilerBugS $ "ImpGen.lookupArray: not an array: " ++ pretty name

-- | The memory location of a named array.
arrayLocation :: VName -> ImpM lore op MemLocation
arrayLocation = fmap entryArrayLocation . lookupArray

-- | Find the memory block entry for a name; a name bound to anything
-- other than a memory block is a compiler bug.
lookupMemory :: VName -> ImpM lore op MemEntry
lookupMemory name = lookupVar name >>= \case
  MemVar _ entry -> return entry
  _              -> compilerBugS $ "Unknown memory block: " ++ pretty name

-- | The destination corresponding to a parameter.  Array parameters
-- become array destinations at their memory location; everything else
-- becomes a scalar destination named after the parameter.
destinationFromParam :: Param (MemBound u) -> ImpM lore op ValueDestination
destinationFromParam param =
  case paramAttr param of
    MemArray _ shape _ (ArrayIn mem ixfun) -> do
      dims <- mapM subExpToDimSize $ shapeDims shape
      let memloc = MemLocation mem dims $ compilePrimExp <$> ixfun
      return $ ArrayDestination $ Just memloc
    _ ->
      return $ ScalarDestination $ paramName param

-- | A destination for a list of parameters, tagged with the base tag of
-- the first parameter's name, if there is one.
destinationFromParams :: [Param (MemBound u)] -> ImpM lore op Destination
destinationFromParams ps = do
  dests <- mapM destinationFromParam ps
  let tag = baseTag . paramName <$> maybeHead ps
  return $ Destination tag dests

-- | A destination for the elements of a pattern, based on their current
-- symbol table entries.  An array whose memory block is itself bound in
-- the pattern's context gets an 'ArrayDestination' with no location.
-- Memory and scalar elements get memory and scalar destinations.
destinationFromPattern :: ExplicitMemorish lore => Pattern lore -> ImpM lore op Destination
destinationFromPattern pat =
  Destination (baseTag <$> maybeHead (patternNames pat)) <$>
  mapM inspect (patternElements pat)
  where ctx_names = patternContextNames pat
        inspect pe = do
          let name = patElemName pe
          entry <- lookupVar name
          return $ case entry of
            ArrayVar _ (ArrayEntry loc@(MemLocation mem _ _) _)
              | mem `elem` ctx_names -> ArrayDestination Nothing
              | otherwise            -> ArrayDestination $ Just loc
            MemVar{}    -> MemoryDestination name
            ScalarVar{} -> ScalarDestination name

-- | Compute the memory block, memory space and byte offset of a single
-- element of a named array (see 'fullyIndexArray'').
fullyIndexArray :: VName -> [Imp.Exp]
                -> ImpM lore op (VName, Imp.Space, Count Bytes)
fullyIndexArray name indices =
  lookupArray name >>= \arr ->
  fullyIndexArray' (entryArrayLocation arr) indices (entryArrayElemType arr)

-- | Compute the memory block, its memory space, and the byte offset of
-- the element at the given indices, using the location's index function
-- scaled by the byte size of the element type.
fullyIndexArray' :: MemLocation -> [Imp.Exp] -> PrimType
                 -> ImpM lore op (VName, Imp.Space, Count Bytes)
fullyIndexArray' (MemLocation mem _ ixfun) indices bt = do
  space <- entryMemSpace <$> lookupMemory mem
  let offset = IxFun.index ixfun indices $ primByteSize bt
  return (mem, space, bytes offset)

-- | An expression that reads the element at the given indices of a
-- named array, honouring the current volatility setting.
readFromArray :: VName -> [Imp.Exp]
              -> ImpM lore op Imp.Exp
readFromArray name indices = do
  arr <- lookupArray name
  let bt = entryArrayElemType arr
  (mem, space, i) <- fullyIndexArray' (entryArrayLocation arr) indices bt
  Imp.index mem i bt space <$> asks envVolatility

-- | Restrict a memory location to a slice.  The shape keeps only the
-- dimensions that are sliced (not fixed), and the index function is
-- sliced correspondingly.
sliceArray :: MemLocation
           -> Slice Imp.Exp
           -> MemLocation
sliceArray (MemLocation mem shape ixfun) slice =
  MemLocation mem (keptDims shape slice) (IxFun.slice ixfun slice)
  where keptDims (d:ds) (DimSlice{}:rest) = d : keptDims ds rest
        keptDims (_:ds) (DimFix{}:rest)   = keptDims ds rest
        keptDims _      _                 = []

-- | Offset the index function of a memory location by the given amount
-- via 'IxFun.offsetIndex'.  The shape is unchanged.
offsetArray :: MemLocation
            -> Imp.Exp
            -> MemLocation
offsetArray (MemLocation mem dims ixfun) amount =
  MemLocation mem dims (IxFun.offsetIndex ixfun amount)

-- | Apply a stride to the index function of a memory location via
-- 'IxFun.strideIndex'.  The shape is unchanged.
strideArray :: MemLocation
            -> Imp.Exp
            -> MemLocation
strideArray (MemLocation mem dims ixfun) step =
  MemLocation mem dims (IxFun.strideIndex ixfun step)

-- | True unless the subexpression has an array type.
subExpNotArray :: SubExp -> ImpM lore op Bool
subExpNotArray se = isNotArray <$> subExpType se
  where isNotArray Array{} = False
        isNotArray _       = True

-- | The size of the outermost dimension of an array.
arrayOuterSize :: ArrayEntry -> Count Elements
arrayOuterSize entry = arrayDimSize 0 entry

-- | The size of dimension @i@ of an array.  The product over at most one
-- element yields 1 (instead of failing) when @i@ is out of range.
arrayDimSize :: Int -> ArrayEntry -> Count Elements
arrayDimSize i entry =
  product [ Imp.dimSizeToExp d | d <- take 1 $ drop i $ entryArrayShape entry ]

-- More complicated read/write operations that use index functions.

-- | Copy using the copy compiler installed in the environment.
copy :: CopyCompiler lore op
copy bt dest src n =
  asks envCopyCompiler >>= \compiler -> compiler bt dest src n

-- | Use an 'Imp.Copy' if possible, otherwise 'copyElementWise'.  A bulk
-- copy is possible when, for both source and destination, the index
-- function matches the location's shape and is linear with a known
-- offset.  The number of elements copied is @n@ times the product of
-- the source's inner dimensions.
defaultCopy :: CopyCompiler lore op
defaultCopy bt dest src n =
  case (,) <$> linearOffset dest <*> linearOffset src of
    Just (destoffset, srcoffset) -> do
      srcspace <- entryMemSpace <$> lookupMemory srcmem
      destspace <- entryMemSpace <$> lookupMemory destmem
      emit $ Imp.Copy destmem (bytes destoffset) destspace
                      srcmem (bytes srcoffset) srcspace
                      ((n * row_size) `withElemType` bt)
    Nothing ->
      copyElementWise bt dest src n
  where bt_size = primByteSize bt
        row_size = product $ map Imp.dimSizeToExp $ drop 1 srcshape
        -- The linear byte offset of a location, if it admits a bulk copy.
        linearOffset (MemLocation _ dims ixfun)
          | ixFunMatchesInnerShape (Shape $ map dimSizeToExp dims) ixfun =
              IxFun.linearWithOffset ixfun bt_size
          | otherwise = Nothing
        MemLocation destmem _ _ = dest
        MemLocation srcmem srcshape _ = src

-- | Copy @n@ rows from the source to the destination with a loop nest
-- that moves one element per iteration.  The nest has one loop per
-- dimension of the destination index function.  The outermost loop runs
-- @n@ times and the inner loops are bounded by the source's inner
-- dimensions.
copyElementWise :: CopyCompiler lore op
copyElementWise bt (MemLocation destmem _ destIxFun) (MemLocation srcmem srcshape srcIxFun) n = do
    is <- replicateM (IxFun.rank destIxFun) (newVName "i")
    declaringLoopVars Int32 is $ do
      let ivars = map varIndex is
          destidx = IxFun.index destIxFun ivars bt_size
          srcidx = IxFun.index srcIxFun ivars bt_size
          bounds = map innerExp $ n : drop 1 (map Imp.dimSizeToExp srcshape)
          -- Wrap the body in one 'Imp.For' per index, outermost first.
          loopNest = foldr (.) id $ zipWith (`Imp.For` Int32) is bounds
      srcspace <- entryMemSpace <$> lookupMemory srcmem
      destspace <- entryMemSpace <$> lookupMemory destmem
      vol <- asks envVolatility
      emit $ loopNest $
        Imp.Write destmem (bytes destidx) bt destspace vol $
        Imp.index srcmem (bytes srcidx) bt srcspace vol
  where bt_size = primByteSize bt

-- | Copy from here to there; both destination and source may be
-- indexed.  The generated code is returned, not emitted.
--
-- When both index lists cover every dimension, a single element is
-- transferred.  Otherwise both locations are sliced by their indices and
-- the slices are copied with 'copy', unless the two slices are the same
-- location, in which case no code is produced.
copyArrayDWIM :: PrimType
              -> MemLocation -> [Imp.Exp]
              -> MemLocation -> [Imp.Exp]
              -> ImpM lore op (Imp.Code op)
copyArrayDWIM bt
  destlocation@(MemLocation _ destshape dest_ixfun) destis
  srclocation@(MemLocation _ srcshape src_ixfun) srcis

  | length srcis == length srcshape, length destis == length destshape = do
      (targetmem, destspace, targetoffset) <-
        fullyIndexArray' destlocation destis bt
      (srcmem, srcspace, srcoffset) <-
        fullyIndexArray' srclocation srcis bt
      vol <- asks envVolatility
      return $ Imp.Write targetmem targetoffset bt destspace vol $
        Imp.index srcmem srcoffset bt srcspace vol

  | otherwise = do
      let destlocation' =
            sliceArray destlocation $
            fullSliceNum (IxFun.shape dest_ixfun) $ map DimFix destis
          srclocation'  =
            sliceArray srclocation $
            fullSliceNum (IxFun.shape src_ixfun) $ map DimFix srcis
      if destlocation' == srclocation'
        then return mempty -- Copy would be no-op.
        -- The row count is the first source dimension not consumed by
        -- the source indices (1 if all were consumed).
        else collect $ copy bt destlocation' srclocation' $
             product $ map Imp.dimSizeToExp $
             take 1 $ drop (length srcis) srcshape

-- | Like 'copyDWIM', but the target is a 'ValueDestination'
-- instead of a variable name.
--
-- @copyDWIMDest dest dest_is src src_is@ emits code copying @src@
-- (indexed by @src_is@) to @dest@ (indexed by @dest_is@).  Invalid
-- combinations (indexed constants, writing non-memory into a memory
-- destination, memory sources into non-memory destinations, etc.)
-- are reported with 'compilerBugS'.
copyDWIMDest :: ValueDestination -> [Imp.Exp] -> SubExp -> [Imp.Exp]
             -> ImpM lore op ()

-- A constant cannot be indexed.
copyDWIMDest _ _ (Constant v) (_:_) =
  compilerBugS $
  unwords ["copyDWIMDest: constant source", pretty v, "cannot be indexed."]
-- Unindexed constant source; the element type is taken from the
-- constant itself.
-- NOTE(review): the ScalarDestination and ArrayElemDestination arms
-- ignore dest_is, whereas the Var case below rejects non-empty indices
-- for those destinations.
copyDWIMDest dest dest_is (Constant v) [] =
  case dest of
  ScalarDestination name ->
    emit $ Imp.SetScalar name $ Imp.ValueExp v
  ArrayElemDestination dest_mem _ dest_space dest_i -> do
    vol <- asks envVolatility
    emit $ Imp.Write dest_mem dest_i bt dest_space vol $ Imp.ValueExp v
  MemoryDestination{} ->
    compilerBugS $
    unwords ["copyDWIMDest: constant source", pretty v, "cannot be written to memory destination."]
  -- Write the constant to the element of the array selected by dest_is.
  ArrayDestination (Just dest_loc) -> do
    (dest_mem, dest_space, dest_i) <-
      fullyIndexArray' dest_loc dest_is bt
    vol <- asks envVolatility
    emit $ Imp.Write dest_mem dest_i bt dest_space vol $ Imp.ValueExp v
  ArrayDestination Nothing ->
    compilerBugS "copyDWIMDest: ArrayDestination Nothing"
  where bt = primValueType v

-- Variable source.  The alternatives below are tried in order, and a
-- guarded alternative whose guard fails falls through to the next one,
-- so the error cases must stay ahead of the general cases they protect.
copyDWIMDest dest dest_is (Var src) src_is = do
  src_entry <- lookupVar src
  case (dest, src_entry) of
    -- Memory-to-memory: rebind the memory block; indices are not used.
    (MemoryDestination mem, MemVar _ (MemEntry _ space)) ->
      emit $ Imp.SetMem mem src space

    (MemoryDestination{}, _) ->
      compilerBugS $
      unwords ["copyDWIMDest: cannot write", pretty src, "to memory destination."]

    (_, MemVar{}) ->
      compilerBugS $
      unwords ["copyDWIMDest: source", pretty src, "is a memory block."]

    (_, ScalarVar _ (ScalarEntry _)) | not $ null src_is ->
      compilerBugS $
      unwords ["copyDWIMDest: prim-typed source", pretty src, "with nonzero indices."]


    (ScalarDestination name, _) | not $ null dest_is ->
      compilerBugS $
      unwords ["copyDWIMDest: prim-typed target", pretty name, "with nonzero indices."]

    (ScalarDestination name, ScalarVar _ (ScalarEntry pt)) ->
      emit $ Imp.SetScalar name $ Imp.var src pt

    -- Read one element of the source array into the scalar.
    (ScalarDestination name, ArrayVar _ arr) -> do
      let bt = entryArrayElemType arr
      (mem, space, i) <-
        fullyIndexArray' (entryArrayLocation arr) src_is bt
      vol <- asks envVolatility
      emit $ Imp.SetScalar name $ Imp.index mem i bt space vol

    -- An array element destination already denotes a single element.
    (ArrayElemDestination{}, _) | not $ null dest_is->
      compilerBugS $
      unwords ["copyDWIMDest: array elemenent destination given indices:", pretty dest_is]

    (ArrayElemDestination dest_mem _ dest_space dest_i,
     ScalarVar _ (ScalarEntry bt)) -> do
      vol <- asks envVolatility
      emit $ Imp.Write dest_mem dest_i bt dest_space vol $ Imp.var src bt

    -- Element-to-element copy; the source must be indexed in every
    -- dimension.
    (ArrayElemDestination dest_mem _ dest_space dest_i, ArrayVar _ src_arr)
      | length (entryArrayShape src_arr) == length src_is -> do
          let bt = entryArrayElemType src_arr
          (src_mem, src_space, src_i) <-
            fullyIndexArray' (entryArrayLocation src_arr) src_is bt
          vol <- asks envVolatility
          emit $ Imp.Write dest_mem dest_i bt dest_space vol $
            Imp.index src_mem src_i bt src_space vol

    (ArrayElemDestination{}, ArrayVar{}) ->
      compilerBugS $
      unwords ["copyDWIMDest: array element destination, but array source",
               pretty src,
               "with incomplete indexing."]

    -- Array-to-array: delegate to copyArrayDWIM and emit its code.
    (ArrayDestination (Just dest_loc), ArrayVar _ src_arr) -> do
      let src_loc = entryArrayLocation src_arr
          bt = entryArrayElemType src_arr
      emit =<< copyArrayDWIM bt dest_loc dest_is src_loc src_is

    -- Scalar into an array: write the element selected by dest_is.
    (ArrayDestination (Just dest_loc), ScalarVar _ (ScalarEntry bt)) -> do
      (dest_mem, dest_space, dest_i) <-
        fullyIndexArray' dest_loc dest_is bt
      vol <- asks envVolatility
      emit $ Imp.Write dest_mem dest_i bt dest_space vol (Imp.var src bt)

    (ArrayDestination Nothing, _) ->
      return () -- Nothing to do; something else set some memory
                -- somewhere.

-- | Copy from here to there; both destination and source may be
-- indexed.  If so, they had better be arrays of enough dimensions.
-- This function will generally just Do What I Mean, and Do The Right
-- Thing.  Both destination and source must be in scope.
--
-- The destination's symbol table entry determines the kind of
-- 'ValueDestination' passed on to 'copyDWIMDest'.
copyDWIM :: VName -> [Imp.Exp] -> SubExp -> [Imp.Exp]
         -> ImpM lore op ()
copyDWIM dest dest_is src src_is = do
  target <- toDestination <$> lookupVar dest
  copyDWIMDest target dest_is src src_is
  where toDestination ScalarVar{} =
          ScalarDestination dest
        toDestination (ArrayVar _ entry) =
          ArrayDestination $ Just $ entryArrayLocation entry
        toDestination MemVar{} =
          MemoryDestination dest

-- | @compileAlloc dest size space@ allocates @size@ bytes of memory in
-- @space@, writing the result to @dest@, which must be a single
-- 'MemoryDestination'.  Any other destination is a compiler bug.
compileAlloc :: Destination -> SubExp -> Space
             -> ImpM lore op ()
compileAlloc dest size space = case dest of
  Destination _ [MemoryDestination mem] -> do
    size' <- compileSubExp size
    emit $ Imp.Allocate mem (Imp.bytes size') space
  _ ->
    compilerBugS $ "compileAlloc: Invalid destination: " ++ show dest

-- | Convert an Imp size back to a subexpression: constants become
-- constants and variable sizes become variables.
dimSizeToSubExp :: Imp.Size -> SubExp
dimSizeToSubExp size = case size of
  Imp.ConstSize n -> constant n
  Imp.VarSize v   -> Var v

-- | Convert an Imp size to an int32-typed Imp expression.
dimSizeToExp :: Imp.Size -> Imp.Exp
dimSizeToExp size =
  compilePrimExp $ primExpFromSubExp int32 $ dimSizeToSubExp size