{-# LANGUAGE GeneralizedNewtypeDeriving, TypeFamilies, FlexibleContexts, TupleSections, FlexibleInstances, MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE DefaultSignatures #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Futhark.Pass.ExplicitAllocations
       ( explicitAllocations
       , explicitAllocationsInStms
       , simplifiable

       , arraySizeInBytesExp
       )
where

import Control.Monad.State
import Control.Monad.Writer
import Control.Monad.Reader
import Control.Monad.RWS.Strict
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified Control.Monad.Fail as Fail
import Data.Maybe

import Futhark.Representation.Kernels
import Futhark.Optimise.Simplify.Lore
  (mkWiseBody, mkWiseLetStm, removeExpWisdom, removeScopeWisdom)
import Futhark.MonadFreshNames
import Futhark.Representation.ExplicitMemory
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Tools
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Optimise.Simplify.Engine (SimpleOps (..))
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Pass
import Futhark.Util (splitFromEnd, takeLast)

data AllocStm = SizeComputation VName (PrimExp VName)
              | Allocation VName SubExp Space
              | ArrayCopy VName VName
                    deriving (Eq, Ord, Show)

bindAllocStm :: (MonadBinder m, Op (Lore m) ~ MemOp inner) =>
                AllocStm -> m ()
bindAllocStm (SizeComputation name pe) =
  letBindNames_ [name] =<< toExp (coerceIntPrimExp Int64 pe)
bindAllocStm (Allocation name size space) =
  letBindNames_ [name] $ Op $ Alloc size space
bindAllocStm (ArrayCopy name src) =
  letBindNames_ [name] $ BasicOp $ Copy src

defaultExpHints :: (Monad m, Attributes lore) => Exp lore -> m [ExpHint]
defaultExpHints e = return $ replicate (expExtTypeSize e) NoHint

class (MonadFreshNames m, HasScope lore m, ExplicitMemorish lore) =>
      Allocator lore m where
  addAllocStm :: AllocStm -> m ()
  askDefaultSpace :: m Space

  default addAllocStm :: (Allocable fromlore lore,
                          Op lore ~ MemOp inner,
                          m ~ AllocM fromlore lore)
                      => AllocStm -> m ()
  addAllocStm (SizeComputation name se) =
    letBindNames_ [name] =<< toExp (coerceIntPrimExp Int64 se)
  addAllocStm (Allocation name size space) =
    letBindNames_ [name] $ Op $ Alloc size space
  addAllocStm (ArrayCopy name src) =
    letBindNames_ [name] $ BasicOp $ Copy src

  -- | The subexpression giving the number of elements we should
  -- allocate space for.  See 'ChunkMap' comment.
  dimAllocationSize :: SubExp -> m SubExp

  default dimAllocationSize :: m ~ AllocM fromlore lore
                               => SubExp -> m SubExp
  dimAllocationSize (Var v) =
    -- It is important to recurse here, as the substitution may itself
    -- be a chunk size.
    maybe (return $ Var v) dimAllocationSize =<< asks (M.lookup v . chunkMap)
  dimAllocationSize size =
    return size

  expHints :: Exp lore -> m [ExpHint]
  expHints = defaultExpHints

allocateMemory :: Allocator lore m =>
                  String -> SubExp -> Space -> m VName
allocateMemory desc size space = do
  v <- newVName desc
  addAllocStm $ Allocation v size space
  return v

computeSize :: Allocator lore m =>
               String -> PrimExp VName -> m SubExp
computeSize desc se = do
  v <- newVName desc
  addAllocStm $ SizeComputation v se
  return $ Var v

type Allocable fromlore tolore =
  (PrettyLore fromlore, PrettyLore tolore,
   ExplicitMemorish tolore,
   SameScope fromlore Kernels,
   RetType fromlore ~ RetType Kernels,
   BranchType fromlore ~ BranchType Kernels,
   BodyAttr fromlore ~ (),
   BodyAttr tolore ~ (),
   ExpAttr tolore ~ (),
   SizeSubst (Op tolore),
   BinderOps tolore)

-- | A mapping from chunk names to their maximum size.  XXX FIXME
-- HACK: This is part of a hack to add loop-invariant allocations to
-- reduce kernels, because memory expansion does not use range
-- analysis yet (it should).
type ChunkMap = M.Map VName SubExp

data AllocEnv fromlore tolore  =
  AllocEnv { chunkMap :: ChunkMap
           , aggressiveReuse :: Bool
             -- ^ Aggressively try to reuse memory in do-loops -
             -- should be True inside kernels, False outside.
           , allocSpace :: Space
             -- ^ When allocating memory, put it in this memory space.
             -- This is primarily used to ensure that group-wide
             -- statements store their results in local memory.
           , allocInOp :: Op fromlore -> AllocM fromlore tolore (Op tolore)
           , envExpHints :: Exp tolore -> AllocM fromlore tolore [ExpHint]
           }

boundDims :: ChunkMap -> AllocEnv fromlore tolore
          -> AllocEnv fromlore tolore
boundDims m env = env { chunkMap = m <> chunkMap env }

-- | Monad for adding allocations to an entire program.
newtype AllocM fromlore tolore a =
  AllocM (BinderT tolore (ReaderT (AllocEnv fromlore tolore) (State VNameSource)) a)
  deriving (Applicative, Functor, Monad,
             MonadFreshNames,
             HasScope tolore,
             LocalScope tolore,
             MonadReader (AllocEnv fromlore tolore))

instance Fail.MonadFail (AllocM fromlore tolore) where
  fail = error . ("AllocM.fail: "++)

instance (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
         MonadBinder (AllocM fromlore tolore) where
  type Lore (AllocM fromlore tolore) = tolore

  mkExpAttrM _ _ = return ()

  mkLetNamesM names e = do
    pat <- patternWithAllocations names e
    return $ Let pat (defAux ()) e

  mkBodyM bnds res = return $ Body () bnds res

  addStms binding = AllocM $ addBinderStms binding
  collectStms (AllocM m) = AllocM $ collectBinderStms m
  certifying cs (AllocM m) = AllocM $ certifyingBinder cs m

instance Allocable fromlore ExplicitMemory =>
         Allocator ExplicitMemory (AllocM fromlore ExplicitMemory) where
  expHints e = do
    f <- asks envExpHints
    f e
  askDefaultSpace = asks allocSpace

runAllocM :: MonadFreshNames m =>
             (Op fromlore -> AllocM fromlore tolore (Op tolore))
          -> (Exp tolore -> AllocM fromlore tolore [ExpHint])
          -> AllocM fromlore tolore a -> m a
runAllocM handleOp hints (AllocM m) =
  fmap fst $ modifyNameSource $ runState $ runReaderT (runBinderT m mempty) env
  where env = AllocEnv mempty False DefaultSpace handleOp hints

-- | Monad for adding allocations to a single pattern.
newtype PatAllocM lore a = PatAllocM (RWS
                                      (Scope lore)
                                      [AllocStm]
                                      VNameSource
                                      a)
                    deriving (Applicative, Functor, Monad,
                              HasScope lore,
                              MonadWriter [AllocStm],
                              MonadFreshNames)

instance Allocator ExplicitMemory (PatAllocM ExplicitMemory) where
  addAllocStm = tell . pure
  dimAllocationSize = return
  askDefaultSpace = return DefaultSpace

runPatAllocM :: MonadFreshNames m =>
                PatAllocM lore a -> Scope lore
             -> m (a, [AllocStm])
runPatAllocM (PatAllocM m) mems =
  modifyNameSource $ frob . runRWS m mems
  where frob (a,s,w) = ((a,w),s)

arraySizeInBytesExp :: Type -> PrimExp VName
arraySizeInBytesExp t =
  product
    [ toInt64 $ product $ map (primExpFromSubExp int32) (arrayDims t)
    , ValueExp $ IntValue $ Int64Value $ primByteSize $ elemType t ]
  where toInt64 = ConvOpExp $ SExt Int32 Int64

arraySizeInBytesExpM :: Allocator lore m => Type -> m (PrimExp VName)
arraySizeInBytesExpM t = do
  dims <- mapM dimAllocationSize (arrayDims t)
  let dim_prod_i32 = product $ map (toInt64 . primExpFromSubExp int32) dims
  let elm_size_i64 = ValueExp $ IntValue $ Int64Value $ primByteSize $ elemType t
  return $ product [ dim_prod_i32, elm_size_i64 ]
  where toInt64 = ConvOpExp $ SExt Int32 Int64

arraySizeInBytes :: Allocator lore m => Type -> m SubExp
arraySizeInBytes = computeSize "bytes" <=< arraySizeInBytesExpM

allocForArray :: Allocator lore m =>
                 Type -> Space -> m VName
allocForArray t space = do
  size <- arraySizeInBytes t
  allocateMemory "mem" size space

allocsForStm :: (Allocator lore m, ExpAttr lore ~ ()) =>
                [Ident] -> [Ident] -> Exp lore
             -> m (Stm lore, [AllocStm])
allocsForStm sizeidents validents e = do
  rts <- expReturns e
  hints <- expHints e
  (ctxElems, valElems, postbnds) <- allocsForPattern sizeidents validents rts hints
  return (Let (Pattern ctxElems valElems) (defAux ()) e,
          postbnds)

patternWithAllocations :: (Allocator lore m, ExpAttr lore ~ ()) =>
                          [VName]
                       -> Exp lore
                       -> m (Pattern lore)
patternWithAllocations names e = do
  (ts',sizes) <- instantiateShapes' =<< expExtType e
  let identForBindage name t =
        pure $ Ident name t
  vals <- sequence [ identForBindage name t | (name, t) <- zip names ts' ]
  (Let pat _ _, extrabnds) <- allocsForStm sizes vals e
  case extrabnds of
    [] -> return pat
    _  -> fail $ "Cannot make allocations for pattern of " ++ pretty e

allocsForPattern :: Allocator lore m =>
                    [Ident] -> [Ident] -> [ExpReturns] -> [ExpHint]
                 -> m ([PatElem ExplicitMemory],
                       [PatElem ExplicitMemory],
                       [AllocStm])
allocsForPattern sizeidents validents rts hints = do
  let sizes' = [ PatElem size $ MemPrim int32 | size <- map identName sizeidents ]
  (vals, (mems, postbnds)) <-
    runWriterT $ forM (zip3 validents rts hints) $ \(ident, rt, hint) -> do
      let shape = arrayShape $ identType ident
      case rt of
        MemPrim _ -> do
          summary <- lift $ summaryForBindage (identType ident) hint
          return $ PatElem (identName ident) summary

        MemMem space ->
          return $ PatElem (identName ident) $
          MemMem space

        MemArray bt _ u (Just (ReturnsInBlock mem ixfun)) ->
          PatElem (identName ident) . MemArray bt shape u .
          ArrayIn mem <$> instantiateIxFun ixfun

        MemArray _ extshape _ Nothing
          | Just _ <- knownShape extshape -> do
            summary <- lift $ summaryForBindage (identType ident) hint
            return $ PatElem (identName ident) summary

        MemArray bt _ u ret -> do
          space <- case ret of
                     Just (ReturnsNewBlock mem_space _ _) -> return mem_space
                     _                                    -> lift askDefaultSpace
          (mem,(ident',ixfun)) <- lift $ memForBindee ident space
          tell ([PatElem (identName mem) $ MemMem space],
                [])
          return $ PatElem (identName ident') $ MemArray bt shape u $
            ArrayIn (identName mem) ixfun

  return (sizes' <> mems,
          vals,
          postbnds)
  where knownShape = mapM known . shapeDims
        known (Free v) = Just v
        known Ext{} = Nothing

instantiateIxFun :: Monad m => ExtIxFun -> m IxFun
instantiateIxFun = traverse $ traverse inst
  where inst Ext{} = fail "instantiateIxFun: not yet"
        inst (Free x) = return x

summaryForBindage :: Allocator lore m =>
                     Type -> ExpHint
                  -> m (MemBound NoUniqueness)
summaryForBindage (Prim bt) _ =
  return $ MemPrim bt
summaryForBindage (Mem space) _ =
  return $ MemMem space
summaryForBindage t@(Array bt shape u) NoHint = do
  m <- allocForArray t =<< askDefaultSpace
  return $ directIndexFunction bt shape u m t
summaryForBindage t (Hint ixfun space) = do
  let bt = elemType t
  bytes <- computeSize "bytes" $
           product [ConvOpExp (SExt Int32 Int64) (product (IxFun.base ixfun)),
                    fromIntegral (primByteSize (elemType t)::Int64)]
  m <- allocateMemory "mem" bytes space
  return $ MemArray bt (arrayShape t) NoUniqueness $ ArrayIn m ixfun

memForBindee :: (MonadFreshNames m) =>
                Ident -> Space
             -> m (Ident,
                   (Ident, IxFun))
memForBindee ident space = do
  mem <- newIdent memname (Mem space)
  return (mem,
          (ident, IxFun.iota $ map (primExpFromSubExp int32) $ arrayDims t))
  where  memname = baseString (identName ident) <> "_mem"
         t       = identType ident

directIndexFunction :: PrimType -> Shape -> u -> VName -> Type -> MemBound u
directIndexFunction bt shape u mem t =
  MemArray bt shape u $ ArrayIn mem $
  IxFun.iota $ map (primExpFromSubExp int32) $ arrayDims t

allocInFParams :: (Allocable fromlore tolore) =>
                  [(FParam fromlore, Space)] ->
                  ([FParam tolore] -> AllocM fromlore tolore a)
               -> AllocM fromlore tolore a
allocInFParams params m = do
  (valparams, memparams) <-
    runWriterT $ mapM (uncurry allocInFParam) params
  let params' = memparams <> valparams
      summary = scopeOfFParams params'
  localScope summary $ m params'

allocInFParam :: (Allocable fromlore tolore) =>
                 FParam fromlore
              -> Space
              -> WriterT [FParam tolore]
                 (AllocM fromlore tolore) (FParam tolore)
allocInFParam param pspace =
  case paramDeclType param of
    Array bt shape u -> do
      let memname = baseString (paramName param) <> "_mem"
          ixfun = IxFun.iota $ map (primExpFromSubExp int32) $ shapeDims shape
      mem <- lift $ newVName memname
      tell [Param mem $ MemMem pspace]
      return param { paramAttr =  MemArray bt shape u $ ArrayIn mem ixfun }
    Prim bt ->
      return param { paramAttr = MemPrim bt }
    Mem space ->
      return param { paramAttr = MemMem space }

allocInMergeParams :: (Allocable fromlore tolore,
                       Allocator tolore (AllocM fromlore tolore)) =>
                      [VName]
                   -> [(FParam fromlore,SubExp)]
                   -> ([FParam tolore]
                       -> [FParam tolore]
                       -> ([SubExp] -> AllocM fromlore tolore ([SubExp], [SubExp]))
                       -> AllocM fromlore tolore a)
                   -> AllocM fromlore tolore a
allocInMergeParams variant merge m = do
  ((valparams, handle_loop_subexps), mem_params) <-
    runWriterT $ unzip <$> mapM allocInMergeParam merge
  let mergeparams' = mem_params <> valparams
      summary = scopeOfFParams mergeparams'

      mk_loop_res ses = do
        (valargs, memargs) <-
          runWriterT $ zipWithM ($) handle_loop_subexps ses
        return (memargs, valargs)

  localScope summary $ m mem_params valparams mk_loop_res
  where allocInMergeParam (mergeparam, Var v)
          | Array bt shape u <- paramDeclType mergeparam = do
              (mem, ixfun) <- lift $ lookupArraySummary v
              Mem space <- lift $ lookupType mem
              reuse <- asks aggressiveReuse
              if space /= Space "local" &&
                 reuse &&
                 u == Unique &&
                 loopInvariantShape mergeparam
                then return (mergeparam { paramAttr = MemArray bt shape Unique $ ArrayIn mem ixfun },
                             lift . ensureArrayIn (paramType mergeparam) mem ixfun)
                else do def_space <- asks allocSpace
                        doDefault mergeparam def_space

        allocInMergeParam (mergeparam, _) = doDefault mergeparam =<< lift askDefaultSpace

        doDefault mergeparam space = do
          mergeparam' <- allocInFParam mergeparam space
          return (mergeparam', linearFuncallArg (paramType mergeparam) space)

        variant_names = variant ++ map (paramName . fst) merge
        loopInvariantShape =
          not . any (`elem` variant_names) . subExpVars . arrayDims . paramType

ensureArrayIn :: (Allocable fromlore tolore,
                  Allocator tolore (AllocM fromlore tolore)) =>
                 Type -> VName -> IxFun -> SubExp
              -> AllocM fromlore tolore SubExp
ensureArrayIn _ _ _ (Constant v) =
  fail $ "ensureArrayIn: " ++ pretty v ++ " cannot be an array."
ensureArrayIn t mem ixfun (Var v) = do
  (src_mem, src_ixfun) <- lookupArraySummary v
  if src_mem == mem && src_ixfun == ixfun
    then return $ Var v
    else do copy <- newIdent (baseString v ++ "_ensure_copy") t
            let summary = MemArray (elemType t) (arrayShape t) NoUniqueness $
                          ArrayIn mem ixfun
                pat = Pattern [] [PatElem (identName copy) summary]
            letBind_ pat $ BasicOp $ Copy v
            return $ Var $ identName copy

ensureDirectArray :: (Allocable fromlore tolore,
                      Allocator tolore (AllocM fromlore tolore)) =>
                     Maybe Space -> VName -> AllocM fromlore tolore (VName, SubExp)
ensureDirectArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  Mem mem_space <- lookupType mem
  default_space <- askDefaultSpace
  if IxFun.isDirect ixfun && maybe True (==mem_space) space_ok
    then return (mem, Var v)
    else needCopy (fromMaybe default_space space_ok)
  where needCopy space =
          -- We need to do a new allocation, copy 'v', and make a new
          -- binding for the size of the memory block.
          allocLinearArray space (baseString v) v

allocLinearArray :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
                    Space -> String -> VName
                 -> AllocM fromlore tolore (VName, SubExp)
allocLinearArray space s v = do
  t <- lookupType v
  mem <- allocForArray t space
  v' <- newIdent (s ++ "_linear") t
  let pat = Pattern [] [PatElem (identName v') $
                        directIndexFunction (elemType t) (arrayShape t)
                        NoUniqueness mem t]
  addStm $ Let pat (defAux ()) $ BasicOp $ Copy v
  return (mem, Var $ identName v')

funcallArgs :: (Allocable fromlore tolore,
                Allocator tolore (AllocM fromlore tolore)) =>
               [(SubExp,Diet)] -> AllocM fromlore tolore [(SubExp,Diet)]
funcallArgs args = do
  (valargs, mem_and_size_args) <- runWriterT $ forM args $ \(arg,d) -> do
    t <- lift $ subExpType arg
    space <- lift askDefaultSpace
    arg' <- linearFuncallArg t space arg
    return (arg', d)
  return $ map (,Observe) mem_and_size_args <> valargs

linearFuncallArg :: (Allocable fromlore tolore,
                     Allocator tolore (AllocM fromlore tolore)) =>
                    Type -> Space -> SubExp
                 -> WriterT [SubExp] (AllocM fromlore tolore) SubExp
linearFuncallArg Array{} space (Var v) = do
  (mem, arg') <- lift $ ensureDirectArray (Just space) v
  tell [Var mem]
  return arg'
linearFuncallArg _ _ arg =
  return arg

explicitAllocations :: Pass Kernels ExplicitMemory
explicitAllocations =
  Pass "explicit allocations" "Transform program to explicit memory representation" $
  intraproceduralTransformation allocInFun

explicitAllocationsInStms :: (MonadFreshNames m, HasScope ExplicitMemory m) =>
                             Stms Kernels -> m (Stms ExplicitMemory)
explicitAllocationsInStms stms = do
  scope <- askScope
  runAllocM handleHostOp kernelExpHints $ localScope scope $ allocInStms stms return

memoryInRetType :: [RetType Kernels] -> [RetType ExplicitMemory]
memoryInRetType ts = evalState (mapM addAttr ts) $ startOfFreeIDRange ts
  where addAttr (Prim t) = return $ MemPrim t
        addAttr Mem{} = fail "memoryInRetType: too much memory"
        addAttr (Array bt shape u) = do
          i <- get <* modify (+1)
          return $ MemArray bt shape u $ ReturnsNewBlock DefaultSpace i $
            IxFun.iota $ map convert $ shapeDims shape

        convert (Ext i) = LeafExp (Ext i) int32
        convert (Free v) = Free <$> primExpFromSubExp int32 v

startOfFreeIDRange :: [TypeBase ExtShape u] -> Int
startOfFreeIDRange = S.size . shapeContext

allocInFun :: MonadFreshNames m => FunDef Kernels -> m (FunDef ExplicitMemory)
allocInFun (FunDef entry fname rettype params fbody) =
  runAllocM handleHostOp kernelExpHints $
  allocInFParams (zip params $ repeat DefaultSpace) $ \params' -> do
    fbody' <- insertStmsM $ allocInFunBody
              (map (const $ Just DefaultSpace) rettype) fbody
    return $ FunDef entry fname (memoryInRetType rettype) params' fbody'

handleHostOp :: HostOp Kernels (SOAC Kernels)
             -> AllocM Kernels ExplicitMemory (MemOp (HostOp ExplicitMemory ()))
handleHostOp (SplitSpace o w i elems_per_thread) =
  return $ Inner $ SplitSpace o w i elems_per_thread
handleHostOp (GetSize key size_class) =
  return $ Inner $ GetSize key size_class
handleHostOp (GetSizeMax size_class) =
  return $ Inner $ GetSizeMax size_class
handleHostOp (CmpSizeLe key size_class x) =
  return $ Inner $ CmpSizeLe key size_class x
handleHostOp (OtherOp op) =
  fail $ "Cannot allocate memory in SOAC: " ++ pretty op

handleHostOp (SegOp op) =
  Inner . SegOp <$> handleSegOp op

handleSegOp :: SegOp Kernels
            -> AllocM Kernels ExplicitMemory (SegOp ExplicitMemory)
handleSegOp op = allocAtLevel (segLevel op) $ mapSegOpM mapper op
  where scope = scopeOfSegSpace $ segSpace op
        mapper = identitySegOpMapper
             { mapOnSegOpBody = localScope scope . allocInKernelBody (segLevel op)
             , mapOnSegOpLambda = allocInBinOpLambda (segLevel op) $ segSpace op
             }

allocAtLevel :: SegLevel -> AllocM fromlore tlore a -> AllocM fromlore tlore a
allocAtLevel lvl = local $ \env -> env { allocSpace = space
                                       , aggressiveReuse = True
                                       }
  where space = case lvl of SegThread{} -> DefaultSpace
                            SegThreadScalar{} -> DefaultSpace
                            SegGroup{} -> Space "local"

bodyReturnMemCtx :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
                    SubExp -> AllocM fromlore tolore [SubExp]
bodyReturnMemCtx Constant{} =
  return []
bodyReturnMemCtx (Var v) = do
  info <- lookupMemInfo v
  case info of
    MemPrim{} -> return []
    MemMem{} -> return [] -- should not happen
    MemArray _ _ _ (ArrayIn mem _) -> return [Var mem]

allocInFunBody :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
                  [Maybe Space] -> Body fromlore -> AllocM fromlore tolore (Body tolore)
allocInFunBody space_oks (Body _ bnds res) =
  allocInStms bnds $ \bnds' -> do
    (res'', allocs) <- collectStms $ do
      res' <- zipWithM ensureDirect space_oks' res
      let (ctx_res, val_res) = splitFromEnd num_vals res'
      mem_ctx_res <- concat <$> mapM bodyReturnMemCtx val_res
      return $ ctx_res <> mem_ctx_res <> val_res
    return $ Body () (bnds'<>allocs) res''
  where num_vals = length space_oks
        space_oks' = replicate (length res - num_vals) Nothing ++ space_oks
        ensureDirect _ se@Constant{} = return se
        ensureDirect space_ok (Var v) = do
          bt <- primType <$> lookupType v
          if bt
            then return $ Var v
            else do (_, v') <- ensureDirectArray space_ok v
                    return v'

allocInStms :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
               Stms fromlore -> (Stms tolore -> AllocM fromlore tolore a)
            -> AllocM fromlore tolore a
allocInStms origbnds m = allocInStms' (stmsToList origbnds) mempty
  where allocInStms' [] bnds' =
          m bnds'
        allocInStms' (x:xs) bnds' = do
          allocbnds <- allocInStm' x
          let summaries = scopeOf allocbnds
          localScope summaries $
            local (boundDims $ mconcat $ map sizeSubst $ stmsToList allocbnds) $
            allocInStms' xs (bnds'<>allocbnds)
        allocInStm' bnd = do
          ((),bnds') <- collectStms $ certifying (stmCerts bnd) $ allocInStm bnd
          return bnds'

allocInStm :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
              Stm fromlore -> AllocM fromlore tolore ()
allocInStm (Let (Pattern sizeElems valElems) _ e) = do
  e' <- allocInExp e
  let sizeidents = map patElemIdent sizeElems
      validents = map patElemIdent valElems
  (bnd, bnds) <- allocsForStm sizeidents validents e'
  addStm bnd
  mapM_ addAllocStm bnds

allocInExp :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
              Exp fromlore -> AllocM fromlore tolore (Exp tolore)
allocInExp (DoLoop ctx val form (Body () bodybnds bodyres)) =
  allocInMergeParams mempty ctx $ \_ ctxparams' _ ->
  allocInMergeParams (map paramName ctxparams') val $
  \new_ctx_params valparams' mk_loop_val -> do
  form' <- allocInLoopForm form
  localScope (scopeOf form') $ do
    (valinit_ctx, valinit') <- mk_loop_val valinit
    body' <- insertStmsM $ allocInStms bodybnds $ \bodybnds' -> do
      ((val_ses,valres'),val_retbnds) <- collectStms $ mk_loop_val valres
      return $ Body () (bodybnds'<>val_retbnds) (ctxres++val_ses++valres')
    return $
      DoLoop
      (zip (ctxparams'++new_ctx_params) (ctxinit++valinit_ctx))
      (zip valparams' valinit')
      form' body'
  where (_ctxparams, ctxinit) = unzip ctx
        (_valparams, valinit) = unzip val
        (ctxres, valres) = splitAt (length ctx) bodyres
allocInExp (Apply fname args rettype loc) = do
  args' <- funcallArgs args
  return $ Apply fname args' (memoryInRetType rettype) loc
allocInExp (If cond tbranch fbranch (IfAttr rets ifsort)) = do
  tbranch' <- allocInFunBody (map (const Nothing) rets) tbranch
  space_oks <- mkSpaceOks (length rets) tbranch'
  fbranch' <- allocInFunBody space_oks fbranch
  let rets' = createBodyReturns rets space_oks
      res_then = bodyResult tbranch'
      res_else = bodyResult fbranch'
      size_ext = length res_then - length rets'
      (ind_ses0, r_then_else) =
        foldl (\(acc_ise,acc_ext) (r_then, r_else, i) ->
                if r_then == r_else then ((i,r_then):acc_ise, acc_ext)
                else (acc_ise, (r_then, r_else):acc_ext)
              ) ([],[]) $ reverse $ zip3 res_then res_else [0..size_ext-1]
      (r_then_ext, r_else_ext) = unzip r_then_else
      ind_ses = zipWith (\(i,se) k -> (i-k,se)) ind_ses0 [0..length ind_ses0 - 1]
      rets'' = foldl (\acc (i,se) -> fixExt i se acc) rets' ind_ses
      tbranch'' = tbranch' { bodyResult = r_then_ext ++ drop size_ext res_then }
      fbranch'' = fbranch' { bodyResult = r_else_ext ++ drop size_ext res_else }
  return $ If cond tbranch'' fbranch'' $ IfAttr rets'' ifsort
allocInExp e = mapExpM alloc e
  where alloc =
          identityMapper { mapOnBody = fail "Unhandled Body in ExplicitAllocations"
                         , mapOnRetType = fail "Unhandled RetType in ExplicitAllocations"
                         , mapOnBranchType = fail "Unhandled BranchType in ExplicitAllocations"
                         , mapOnFParam = fail "Unhandled FParam in ExplicitAllocations"
                         , mapOnLParam = fail "Unhandled LParam in ExplicitAllocations"
                         , mapOnOp = \op -> do handle <- asks allocInOp
                                               handle op
                         }

mkSpaceOks :: (ExplicitMemorish tolore, LocalScope tolore m) =>
              Int -> Body tolore -> m [Maybe Space]
mkSpaceOks num_vals (Body _ stms res) =
  inScopeOf stms $
  mapM mkSpaceOK $ takeLast num_vals res
  where mkSpaceOK (Var v) = do
          v_info <- lookupMemInfo v
          case v_info of MemArray _ _ _ (ArrayIn mem _) -> do
                           mem_info <- lookupMemInfo mem
                           case mem_info of MemMem space -> return $ Just space
                                            _ -> return Nothing
                         _ -> return Nothing
        mkSpaceOK _ = return Nothing

createBodyReturns :: [ExtType] -> [Maybe Space] -> [BodyReturns]
createBodyReturns ts spaces =
  evalState (zipWithM inspect ts spaces) $ S.size $ shapeContext ts
  where inspect (Array pt shape u) space = do
          i <- get <* modify (+1)
          let space' = fromMaybe DefaultSpace space
          return $ MemArray pt shape u $ ReturnsNewBlock space' i $
            IxFun.iota $ map convert $ shapeDims shape
        inspect (Prim pt) _ =
          return $ MemPrim pt
        inspect (Mem space) _ =
          return $ MemMem space

        convert (Ext i) = LeafExp (Ext i) int32
        convert (Free v) = Free <$> primExpFromSubExp int32 v

allocInLoopForm :: (Allocable fromlore tolore,
                    Allocator tolore (AllocM fromlore tolore)) =>
                   LoopForm fromlore -> AllocM fromlore tolore (LoopForm tolore)
allocInLoopForm (WhileLoop v) = return $ WhileLoop v
allocInLoopForm (ForLoop i it n loopvars) =
  ForLoop i it n <$> mapM allocInLoopVar loopvars
  where allocInLoopVar (p,a) = do
          (mem, ixfun) <- lookupArraySummary a
          case paramType p of
            Array bt shape u ->
              let ixfun' = IxFun.slice ixfun $
                           fullSliceNum (IxFun.shape ixfun) [DimFix $ LeafExp i int32]
              in return (p { paramAttr = MemArray bt shape u $ ArrayIn mem ixfun' }, a)
            Prim bt ->
              return (p { paramAttr = MemPrim bt }, a)
            Mem space ->
              return (p { paramAttr = MemMem space }, a)

allocInBinOpLambda :: SegLevel -> SegSpace -> Lambda Kernels
                   -> AllocM Kernels ExplicitMemory (Lambda ExplicitMemory)
allocInBinOpLambda lvl (SegSpace flat _) lam = do
  num_threads <- letSubExp "num_threads" $ BasicOp $ BinOp (Mul Int32)
                 (unCount (segNumGroups lvl)) (unCount (segGroupSize lvl))
  let (acc_params, arr_params) =
        splitAt (length (lambdaParams lam) `div` 2) $ lambdaParams lam
      index_x = LeafExp flat int32
      index_y = index_x + primExpFromSubExp int32 num_threads
  (acc_params', arr_params') <-
    allocInBinOpParams num_threads index_x index_y acc_params arr_params

  local (\env -> env { envExpHints = inThreadExpHints }) $
    allocInLambda (acc_params' ++ arr_params')
    (lambdaBody lam) (lambdaReturnType lam)

allocInBinOpParams :: SubExp
                   -> PrimExp VName -> PrimExp VName
                   -> [LParam Kernels]
                   -> [LParam Kernels]
                   -> AllocM Kernels ExplicitMemory ([LParam ExplicitMemory], [LParam ExplicitMemory])
allocInBinOpParams num_threads my_id other_id xs ys = unzip <$> zipWithM alloc xs ys
  where alloc x y =
          case paramType x of
            Array bt shape u -> do
              twice_num_threads <-
                letSubExp "twice_num_threads" $
                BasicOp $ BinOp (Mul Int32) num_threads $ intConst Int32 2
              let t = paramType x `arrayOfRow` twice_num_threads
              mem <- allocForArray t DefaultSpace
              -- XXX: this iota ixfun is a bit inefficient; leading to uncoalesced access.
              let ixfun_base = IxFun.iota $
                               map (primExpFromSubExp int32) (arrayDims t)
                  ixfun_x = IxFun.slice ixfun_base $
                            fullSliceNum (IxFun.shape ixfun_base) [DimFix my_id]
                  ixfun_y = IxFun.slice ixfun_base $
                            fullSliceNum (IxFun.shape ixfun_base) [DimFix other_id]
              return (x { paramAttr = MemArray bt shape u $ ArrayIn mem ixfun_x },
                      y { paramAttr = MemArray bt shape u $ ArrayIn mem ixfun_y })
            Prim bt ->
              return (x { paramAttr = MemPrim bt },
                      y { paramAttr = MemPrim bt })
            Mem space ->
              return (x { paramAttr = MemMem space },
                      y { paramAttr = MemMem space })

allocInLambda :: [LParam ExplicitMemory] -> Body Kernels -> [Type]
              -> AllocM Kernels ExplicitMemory (Lambda ExplicitMemory)
allocInLambda params body rettype = do
  body' <- localScope (scopeOfLParams params) $
           allocInStms (bodyStms body) $ \bnds' ->
           return $ Body () bnds' $ bodyResult body
  return $ Lambda params body' rettype

allocInKernelBody :: SegLevel -> KernelBody Kernels
                  -> AllocM Kernels ExplicitMemory (KernelBody ExplicitMemory)
allocInKernelBody lvl (KernelBody () stms res) =
  local f $ allocInStms stms $ \stms' -> return $ KernelBody () stms' res
  where f = case lvl of SegThread{} -> inThread
                        SegThreadScalar{} -> inThread
                        SegGroup{} -> inGroup
        inThread env = env { envExpHints = inThreadExpHints }
        inGroup env = env { envExpHints = inGroupExpHints }

class SizeSubst op where
  opSizeSubst :: PatternT attr -> op -> ChunkMap

instance SizeSubst (HostOp lore op) where
  opSizeSubst (Pattern _ [size]) (SplitSpace _ _ _ elems_per_thread) =
    M.singleton (patElemName size) elems_per_thread
  opSizeSubst _ _ = mempty

instance SizeSubst op => SizeSubst (MemOp op) where
  opSizeSubst pat (Inner op) = opSizeSubst pat op
  opSizeSubst _ _ = mempty

sizeSubst :: SizeSubst (Op lore) => Stm lore -> ChunkMap
sizeSubst (Let pat _ (Op op)) = opSizeSubst pat op
sizeSubst _ = mempty

mkLetNamesB' :: (Op (Lore m) ~ MemOp inner,
                 MonadBinder m, ExpAttr (Lore m) ~ (),
                 Allocator (Lore m) (PatAllocM (Lore m))) =>
                ExpAttr (Lore m) -> [VName] -> Exp (Lore m) -> m (Stm (Lore m))
mkLetNamesB' attr names e = do
  scope <- askScope
  pat <- bindPatternWithAllocations scope names e
  return $ Let pat (defAux attr) e

mkLetNamesB'' :: (Op (Lore m) ~ MemOp inner, ExpAttr lore ~ (),
                   HasScope (Engine.Wise lore) m, Allocator lore (PatAllocM lore),
                   MonadBinder m, Engine.CanBeWise (Op lore)) =>
                 [VName] -> Exp (Engine.Wise lore)
              -> m (Stm (Engine.Wise lore))
mkLetNamesB'' names e = do
  scope <- Engine.removeScopeWisdom <$> askScope
  (pat, prestms) <- runPatAllocM (patternWithAllocations names $ Engine.removeExpWisdom e) scope
  mapM_ bindAllocStm prestms
  let pat' = Engine.addWisdomToPattern pat e
      attr = Engine.mkWiseExpAttr pat' () e
  return $ Let pat' (defAux attr) e

instance BinderOps ExplicitMemory where
  mkExpAttrB _ _ = return ()
  mkBodyB stms res = return $ Body () stms res
  mkLetNamesB = mkLetNamesB' ()

instance BinderOps (Engine.Wise ExplicitMemory) where
  mkExpAttrB pat e = return $ Engine.mkWiseExpAttr pat () e
  mkBodyB stms res = return $ Engine.mkWiseBody () stms res
  mkLetNamesB = mkLetNamesB''

simplifiable :: (Engine.SimplifiableLore lore,
                 ExpAttr lore ~ (),
                 BodyAttr lore ~ (),
                 Op lore ~ MemOp inner,
                 Allocator lore (PatAllocM lore)) =>
                (inner -> Engine.SimpleM lore (Engine.OpWithWisdom inner, Stms (Engine.Wise lore)))
             -> SimpleOps lore
simplifiable simplifyInnerOp =
  SimpleOps mkExpAttrS' mkBodyS' mkLetNamesS' simplifyOp
  where mkExpAttrS' _ pat e =
          return $ Engine.mkWiseExpAttr pat () e

        mkBodyS' _ bnds res = return $ mkWiseBody () bnds res

        mkLetNamesS' vtable names e = do
          (pat', stms) <- runBinder $ bindPatternWithAllocations env names $
                          removeExpWisdom e
          return (mkWiseLetStm pat' (defAux ()) e, stms)
          where env = removeScopeWisdom $ ST.toScope vtable

        simplifyOp (Alloc size space) =
          (,) <$> (Alloc <$> Engine.simplify size <*> pure space) <*> pure mempty
        simplifyOp (Inner k) = do (k', hoisted) <- simplifyInnerOp k
                                  return (Inner k', hoisted)

bindPatternWithAllocations :: (MonadBinder m,
                               ExpAttr lore ~ (),
                               Op (Lore m) ~ MemOp inner,
                               Allocator lore (PatAllocM lore)) =>
                              Scope lore -> [VName] -> Exp lore
                           -> m (Pattern lore)
bindPatternWithAllocations types names e = do
  (pat,prebnds) <- runPatAllocM (patternWithAllocations names e) types
  mapM_ bindAllocStm prebnds
  return pat

data ExpHint = NoHint
             | Hint IxFun Space

kernelExpHints :: Allocator ExplicitMemory m => Exp ExplicitMemory -> m [ExpHint]
kernelExpHints (BasicOp (Manifest perm v)) = do
  dims <- arrayDims <$> lookupType v
  let perm_inv = rearrangeInverse perm
      dims' = rearrangeShape perm dims
      ixfun = IxFun.permute (IxFun.iota $ map (primExpFromSubExp int32) dims')
              perm_inv
  return [Hint ixfun DefaultSpace]

kernelExpHints (Op (Inner (SegOp (SegMap lvl@SegThread{} space ts body)))) =
  zipWithM (mapResultHint lvl space) ts $ kernelBodyResult body

kernelExpHints (Op (Inner (SegOp (SegRed lvl@SegThread{} space reds ts body)))) =
  (map (const NoHint) red_res <>) <$> zipWithM (mapResultHint lvl space) (drop num_reds ts) map_res
  where num_reds = segRedResults reds
        (red_res, map_res) = splitAt num_reds $ kernelBodyResult body

kernelExpHints e =
  return $ replicate (expExtTypeSize e) NoHint

mapResultHint :: Allocator lore m =>
                 SegLevel -> SegSpace -> Type -> KernelResult -> m ExpHint
mapResultHint lvl space = hint
  where num_threads = primExpFromSubExp int32 (unCount $ segNumGroups lvl) *
                      primExpFromSubExp int32 (unCount $ segGroupSize lvl)

        -- Heuristic: do not rearrange for returned arrays that are
        -- sufficiently small.
        coalesceReturnOfShape _ [] = False
        coalesceReturnOfShape bs [Constant (IntValue (Int32Value d))] = bs * d > 4
        coalesceReturnOfShape _ _ = True

        hint t (Returns _)
          | coalesceReturnOfShape (primByteSize (elemType t)) $ arrayDims t = do
              let space_dims = segSpaceDims space
              t_dims <- mapM dimAllocationSize $ arrayDims t
              return $ Hint (innermost space_dims t_dims) DefaultSpace

        hint t (ConcatReturns SplitStrided{} w _ _) = do
          t_dims <- mapM dimAllocationSize $ arrayDims t
          return $ Hint (innermost [w] t_dims) DefaultSpace

        hint Prim{} (ConcatReturns SplitContiguous w elems_per_thread _) = do
          let ixfun_base = IxFun.iota [num_threads, primExpFromSubExp int32 elems_per_thread]
              ixfun_tr = IxFun.permute ixfun_base [1,0]
              ixfun = IxFun.reshape ixfun_tr $ map (DimNew . primExpFromSubExp int32) [w]
          return $ Hint ixfun DefaultSpace

        hint _ _ = return NoHint

innermost :: [SubExp] -> [SubExp] -> IxFun
innermost space_dims t_dims =
  let r = length t_dims
      dims = space_dims ++ t_dims
      perm = [length space_dims..length space_dims+r-1] ++
             [0..length space_dims-1]
      perm_inv = rearrangeInverse perm
      dims_perm = rearrangeShape perm dims
      ixfun_base = IxFun.iota $ map (primExpFromSubExp int32) dims_perm
      ixfun_rearranged = IxFun.permute ixfun_base perm_inv
  in ixfun_rearranged

inGroupExpHints :: Allocator ExplicitMemory m => Exp ExplicitMemory -> m [ExpHint]
inGroupExpHints (Op (Inner (SegOp (SegMap SegThreadScalar{} space ts _)))) = return $ do
  t <- ts
  case t of
    Prim pt ->
      return $ Hint (IxFun.iota $ map (primExpFromSubExp int32) $
                     segSpaceDims space ++ arrayDims t) $ Space $ scalarMemory pt
    _ ->
      return NoHint
inGroupExpHints e = return $ replicate (expExtTypeSize e) NoHint

inThreadExpHints :: Allocator ExplicitMemory m => Exp ExplicitMemory -> m [ExpHint]
inThreadExpHints e =
  mapM maybePrivate =<< expExtType e
  where maybePrivate t
          | arrayRank t > 0,
            Just t' <- hasStaticShape t,
            all semiStatic $ arrayDims t' = do
              alloc_dims <- mapM dimAllocationSize $ arrayDims t'
              let ixfun = IxFun.iota $ map (primExpFromSubExp int32) alloc_dims
              return $ Hint ixfun $ Space "private"
          | otherwise =
              return NoHint

        semiStatic Constant{} = True
        semiStatic _ = False