{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.CodeGen.ImpGen.Kernels
  ( compileProg
  )
  where

import Control.Arrow ((&&&))
import Control.Monad.Except
import Control.Monad.Reader
import Data.Maybe
import Data.Semigroup ((<>))
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.List

import Prelude hiding (quot)

import Futhark.Error
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Representation.ExplicitMemory
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpCode.Kernels (bytes)
import qualified Futhark.CodeGen.ImpGen as ImpGen
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.CodeGen.SetDefaultSpace
import Futhark.Tools (partitionChunkedKernelLambdaParameters, fullSliceNum)
import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem, IntegralExp)
import Futhark.Util (splitAt3)

type CallKernelGen = ImpGen.ImpM ExplicitMemory Imp.HostOp
type InKernelGen = ImpGen.ImpM InKernel Imp.KernelOp

callKernelOperations :: ImpGen.Operations ExplicitMemory Imp.HostOp
callKernelOperations =
  ImpGen.Operations { ImpGen.opsExpCompiler = expCompiler
                    , ImpGen.opsCopyCompiler = callKernelCopy
                    , ImpGen.opsOpCompiler = opCompiler
                    , ImpGen.opsBodyCompiler = ImpGen.defCompileBody
                    }

inKernelOperations :: KernelConstants -> ImpGen.Operations InKernel Imp.KernelOp
inKernelOperations constants = (ImpGen.defaultOperations $ compileInKernelOp constants)
                               { ImpGen.opsCopyCompiler = inKernelCopy
                               , ImpGen.opsExpCompiler = inKernelExpCompiler
                               , ImpGen.opsBodyCompiler = compileNestedKernelBody constants
                               }

compileProg :: MonadFreshNames m => Prog ExplicitMemory -> m (Either InternalError Imp.Program)
compileProg prog =
  fmap (setDefaultSpace (Imp.Space "device")) <$>
  ImpGen.compileProg callKernelOperations (Imp.Space "device") prog

opCompiler :: ImpGen.Destination -> Op ExplicitMemory
           -> CallKernelGen ()
opCompiler dest (Alloc e space) =
  ImpGen.compileAlloc dest e space
opCompiler dest (Inner kernel) =
  kernelCompiler dest kernel

compileInKernelOp :: KernelConstants -> ImpGen.Destination -> Op InKernel
                  -> InKernelGen ()
compileInKernelOp _ (ImpGen.Destination _ [ImpGen.MemoryDestination mem]) Alloc{} =
  compilerLimitationS $ "Cannot allocate memory block " ++ pretty mem ++ " in kernel."
compileInKernelOp _ dest Alloc{} =
  compilerBugS $ "Invalid target for in-kernel allocation: " ++ show dest
compileInKernelOp constants dest (Inner op) =
  compileKernelExp constants dest op

-- | Compile a kernel operation: either one of the kernel size
-- queries or an actual kernel.
kernelCompiler :: ImpGen.Destination -> Kernel InKernel
               -> CallKernelGen ()

kernelCompiler dest (GetSize key size_class) = do
  [v] <- ImpGen.funcallTargets dest
  ImpGen.emit $ Imp.Op $ Imp.GetSize v key size_class

kernelCompiler dest (CmpSizeLe key size_class x) = do
  [v] <- ImpGen.funcallTargets dest
  ImpGen.emit =<< Imp.Op . Imp.CmpSizeLe v key size_class <$> ImpGen.compileSubExp x

kernelCompiler dest (GetSizeMax size_class) = do
  [v] <- ImpGen.funcallTargets dest
  ImpGen.emit $ Imp.Op $ Imp.GetSizeMax v size_class

kernelCompiler dest (Kernel desc space _ kernel_body) = do

  num_groups' <- ImpGen.subExpToDimSize $ spaceNumGroups space
  group_size' <- ImpGen.subExpToDimSize $ spaceGroupSize space
  num_threads' <- ImpGen.subExpToDimSize $ spaceNumThreads space

  let bound_in_kernel =
        M.keys $
        scopeOfKernelSpace space <>
        scopeOf (kernelBodyStms kernel_body)

  let global_tid = spaceGlobalId space
      local_tid = spaceLocalId space
      group_id = spaceGroupId space
  wave_size <- newVName "wave_size"
  inner_group_size <- newVName "group_size"
  thread_active <- newVName "thread_active"

  let (space_is, space_dims) = unzip $ spaceDimensions space
  space_dims' <- mapM ImpGen.compileSubExp space_dims
  let constants = KernelConstants global_tid local_tid group_id
                  group_size' num_threads'
                  (Imp.VarSize wave_size) (zip space_is space_dims')
                  (Imp.var thread_active Bool) mempty

  kernel_body' <-
    makeAllMemoryGlobal $
    ImpGen.subImpM_ (inKernelOperations constants) $
    ImpGen.declaringPrimVar wave_size int32 $
    ImpGen.declaringPrimVar inner_group_size int32 $
    ImpGen.declaringPrimVar thread_active Bool $
    ImpGen.declaringScope Nothing (scopeOfKernelSpace space) $ do

    ImpGen.emit $
      Imp.Op (Imp.GetGlobalId global_tid 0) <>
      Imp.Op (Imp.GetLocalId local_tid 0) <>
      Imp.Op (Imp.GetLocalSize inner_group_size 0) <>
      Imp.Op (Imp.GetLockstepWidth wave_size) <>
      Imp.Op (Imp.GetGroupId group_id 0)

    setSpaceIndices space

    ImpGen.emit $ Imp.SetScalar thread_active (isActive $ spaceDimensions space)

    compileKernelBody dest constants kernel_body

  (uses, local_memory) <- computeKernelUses kernel_body' bound_in_kernel

  forM_ (kernelHints desc) $ \(s,v) -> do
    ty <- case v of
      Constant pv -> return $ Prim $ primValueType pv
      Var vn -> lookupType vn
    unless (primType ty) $ fail $ concat [ "debugKernelHint '", s, "'"
                                         , " in kernel '", kernelName desc, "'"
                                         , " did not have primType value." ]

    ImpGen.compileSubExp v >>= ImpGen.emit . Imp.DebugPrint s (elemType ty)

  ImpGen.emit $ Imp.Op $ Imp.CallKernel $ Imp.AnyKernel Imp.Kernel
            { Imp.kernelBody = kernel_body'
            , Imp.kernelLocalMemory = local_memory
            , Imp.kernelUses = uses
            , Imp.kernelNumGroups = num_groups'
            , Imp.kernelGroupSize = group_size'
            , Imp.kernelName = global_tid
            , Imp.kernelDesc = kernelName desc
            }

expCompiler :: ImpGen.ExpCompiler ExplicitMemory Imp.HostOp
-- We generate a simple kernel for iota and replicate.
expCompiler
  (ImpGen.Destination tag [ImpGen.ArrayDestination (Just destloc)])
  (BasicOp (Iota n x s et)) = do
  thread_gid <- maybe (newVName "thread_gid") (return . VName (nameFromString "thread_gid")) tag

  makeAllMemoryGlobal $ do
    (destmem, destspace, destidx) <-
      ImpGen.fullyIndexArray' destloc [ImpGen.varIndex thread_gid] (IntType et)

    n' <- ImpGen.compileSubExp n
    x' <- ImpGen.compileSubExp x
    s' <- ImpGen.compileSubExp s

    let body = Imp.Write destmem destidx (IntType et) destspace Imp.Nonvolatile $
               Imp.ConvOpExp (SExt Int32 et) (Imp.var thread_gid int32) * s' + x'

    (group_size, num_groups) <- computeMapKernelGroups n'

    (body_uses, _) <- computeKernelUses
                      (freeIn body <> freeIn [n',x',s'])
                      [thread_gid]

    ImpGen.emit $ Imp.Op $ Imp.CallKernel $ Imp.Map Imp.MapKernel
      { Imp.mapKernelThreadNum = thread_gid
      , Imp.mapKernelDesc = "iota"
      , Imp.mapKernelNumGroups = Imp.VarSize num_groups
      , Imp.mapKernelGroupSize = Imp.VarSize group_size
      , Imp.mapKernelSize = n'
      , Imp.mapKernelUses = body_uses
      , Imp.mapKernelBody = body
      }

expCompiler
  (ImpGen.Destination tag [dest]) (BasicOp (Replicate (Shape ds) se)) = do
  constants <- simpleKernelConstants tag "replicate"

  t <- subExpType se
  let thread_gid = kernelGlobalThreadId constants
      row_dims = arrayDims t
      dims = ds ++ row_dims
      is' = unflattenIndex (map (ImpGen.compileSubExpOfType int32) dims) $
            ImpGen.varIndex thread_gid
  ds' <- mapM ImpGen.compileSubExp ds

  makeAllMemoryGlobal $ do
    body <- ImpGen.subImpM_ (inKernelOperations constants) $
      ImpGen.copyDWIMDest dest is' se $ drop (length ds) is'

    dims' <- mapM ImpGen.compileSubExp dims
    (group_size, num_groups) <- computeMapKernelGroups $ product dims'

    (body_uses, _) <- computeKernelUses
                      (freeIn body <> freeIn ds')
                      [thread_gid]

    ImpGen.emit $ Imp.Op $ Imp.CallKernel $ Imp.Map Imp.MapKernel
      { Imp.mapKernelThreadNum = thread_gid
      , Imp.mapKernelDesc = "replicate"
      , Imp.mapKernelNumGroups = Imp.VarSize num_groups
      , Imp.mapKernelGroupSize = Imp.VarSize group_size
      , Imp.mapKernelSize = product dims'
      , Imp.mapKernelUses = body_uses
      , Imp.mapKernelBody = body
      }

-- Allocation in the "local" space is just a placeholder.
expCompiler _ (Op (Alloc _ (Space "local"))) =
  return ()

expCompiler dest e =
  ImpGen.defCompileExp dest e

callKernelCopy :: ImpGen.CopyCompiler ExplicitMemory Imp.HostOp
callKernelCopy bt
  destloc@(ImpGen.MemLocation destmem destshape destIxFun)
  srcloc@(ImpGen.MemLocation srcmem srcshape srcIxFun)
  n
  | Just (destoffset, srcoffset,
          num_arrays, size_x, size_y,
          src_elems, dest_elems) <- isMapTransposeKernel bt destloc srcloc =
  ImpGen.emit $ Imp.Op $ Imp.CallKernel $
  Imp.MapTranspose bt
  destmem destoffset
  srcmem srcoffset
  num_arrays size_x size_y
  src_elems dest_elems

  | bt_size <- primByteSize bt,
    ixFunMatchesInnerShape
      (Shape $ map Imp.sizeToExp destshape) destIxFun,
    ixFunMatchesInnerShape
      (Shape $ map Imp.sizeToExp srcshape) srcIxFun,
    Just destoffset <-
      IxFun.linearWithOffset destIxFun bt_size,
    Just srcoffset  <-
      IxFun.linearWithOffset srcIxFun bt_size = do
        let row_size = product $ map ImpGen.dimSizeToExp $ drop 1 srcshape
        srcspace <- ImpGen.entryMemSpace <$> ImpGen.lookupMemory srcmem
        destspace <- ImpGen.entryMemSpace <$> ImpGen.lookupMemory destmem
        ImpGen.emit $ Imp.Copy
          destmem (bytes destoffset) destspace
          srcmem (bytes srcoffset) srcspace $
          (n * row_size) `Imp.withElemType` bt

  | otherwise = do
  global_thread_index <- newVName "copy_global_thread_index"

  -- Note that the shapes of the destination and the source are
  -- necessarily the same.
  let shape = map Imp.sizeToExp srcshape
      shape_se = map (Imp.innerExp . ImpGen.dimSizeToExp) srcshape
      dest_is = unflattenIndex shape_se $ ImpGen.varIndex global_thread_index
      src_is = dest_is

  makeAllMemoryGlobal $ do
    (_, destspace, destidx) <- ImpGen.fullyIndexArray' destloc dest_is bt
    (_, srcspace, srcidx) <- ImpGen.fullyIndexArray' srcloc src_is bt

    let body = Imp.Write destmem destidx bt destspace Imp.Nonvolatile $
               Imp.index srcmem srcidx bt srcspace Imp.Nonvolatile

    destmem_size <- ImpGen.entryMemSize <$> ImpGen.lookupMemory destmem
    let writes_to = [Imp.MemoryUse destmem destmem_size]

    reads_from <- readsFromSet $
                  S.singleton srcmem <>
                  freeIn destIxFun <> freeIn srcIxFun <> freeIn destshape

    let kernel_size = Imp.innerExp n * product (drop 1 shape)
    (group_size, num_groups) <- computeMapKernelGroups kernel_size

    let bound_in_kernel = [global_thread_index]
    (body_uses, _) <- computeKernelUses (kernel_size, body) bound_in_kernel

    ImpGen.emit $ Imp.Op $ Imp.CallKernel $ Imp.Map Imp.MapKernel
      { Imp.mapKernelThreadNum = global_thread_index
      , Imp.mapKernelDesc = "copy"
      , Imp.mapKernelNumGroups = Imp.VarSize num_groups
      , Imp.mapKernelGroupSize = Imp.VarSize group_size
      , Imp.mapKernelSize = kernel_size
      , Imp.mapKernelUses = nub $ body_uses ++ writes_to ++ reads_from
      , Imp.mapKernelBody = body
      }

-- | We have no bulk copy operation (e.g. memmove) inside kernels, so
-- turn any copy into a loop.
inKernelCopy :: ImpGen.CopyCompiler InKernel Imp.KernelOp
inKernelCopy = ImpGen.copyElementWise

inKernelExpCompiler :: ImpGen.ExpCompiler InKernel Imp.KernelOp
inKernelExpCompiler _ (BasicOp (Assert _ _ (loc, locs))) =
  compilerLimitationS $
  unlines [ "Cannot compile assertion at " ++
            intercalate " -> " (reverse $ map locStr $ loc:locs) ++
            " inside parallel kernel."
          , "As a workaround, surround the expression with 'unsafe'."]
-- The static arrays stuff does not work inside kernels.
inKernelExpCompiler (ImpGen.Destination _ [dest]) (BasicOp (ArrayLit es _)) =
  forM_ (zip [0..] es) $ \(i,e) ->
  ImpGen.copyDWIMDest dest [fromIntegral (i::Int32)] e []
inKernelExpCompiler dest e =
  ImpGen.defCompileExp dest e

computeKernelUses :: FreeIn a =>
                     a -> [VName]
                  -> CallKernelGen ([Imp.KernelUse], [Imp.LocalMemoryUse])
computeKernelUses kernel_body bound_in_kernel = do
    let actually_free = freeIn kernel_body `S.difference` S.fromList bound_in_kernel

    -- Compute the variables that we need to pass to the kernel.
    reads_from <- readsFromSet actually_free

    -- Are we using any local memory?
    local_memory <- computeLocalMemoryUse actually_free
    return (nub reads_from, nub local_memory)

readsFromSet :: Names -> CallKernelGen [Imp.KernelUse]
readsFromSet free =
  fmap catMaybes $
  forM (S.toList free) $ \var -> do
    t <- lookupType var
    case t of
      Array {} -> return Nothing
      Mem _ (Space "local") -> return Nothing
      Mem memsize _ -> Just <$> (Imp.MemoryUse var <$>
                                 ImpGen.subExpToDimSize memsize)
      Prim bt ->
        isConstExp var >>= \case
          Just ce -> return $ Just $ Imp.ConstUse var ce
          Nothing | bt == Cert -> return Nothing
                  | otherwise  -> return $ Just $ Imp.ScalarUse var bt

computeLocalMemoryUse :: Names -> CallKernelGen [Imp.LocalMemoryUse]
computeLocalMemoryUse free =
  fmap catMaybes $
  forM (S.toList free) $ \var -> do
    t <- lookupType var
    case t of
      Mem memsize (Space "local") -> do
        memsize' <- localMemSize =<< ImpGen.subExpToDimSize memsize
        return $ Just (var, memsize')
      _ -> return Nothing

localMemSize :: Imp.MemSize -> CallKernelGen (Either Imp.MemSize Imp.KernelConstExp)
localMemSize (Imp.ConstSize x) =
  return $ Right $ ValueExp $ IntValue $ Int64Value x
localMemSize (Imp.VarSize v) = isConstExp v >>= \case
  Just e | isStaticExp e -> return $ Right e
  _ -> return $ Left $ Imp.VarSize v

-- | Only some constant expressions qualify as *static* expressions,
-- which we can use for static memory allocation.  This is a bit of a
-- hack, as it is primarily motivated by what you can put as the size
-- when declaring an array in C.
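--
-- As a purely illustrative sketch (assuming @k@ names a size constant
-- obtained via 'Imp.SizeConst'), an expression like
--
-- > BinOpExp (Mul Int32) (LeafExp (Imp.SizeConst k) int32)
-- >                      (ValueExp (IntValue (Int32Value 4)))
--
-- is static, whereas anything involving, say, division is rejected.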
isStaticExp :: Imp.KernelConstExp -> Bool
isStaticExp LeafExp{} = True
isStaticExp ValueExp{} = True
isStaticExp (BinOpExp Add{} x y) = isStaticExp x && isStaticExp y
isStaticExp (BinOpExp Sub{} x y) = isStaticExp x && isStaticExp y
isStaticExp (BinOpExp Mul{} x y) = isStaticExp x && isStaticExp y
isStaticExp _ = False

isConstExp :: VName -> CallKernelGen (Maybe Imp.KernelConstExp)
isConstExp v = do
  vtable <- asks ImpGen.envVtable
  let lookupConstExp name = constExp =<< hasExp =<< M.lookup name vtable
      constExp (Op (Inner (GetSize key _))) = Just $ LeafExp (Imp.SizeConst key) int32
      constExp e = primExpFromExp lookupConstExp e
  return $ lookupConstExp v
  where hasExp (ImpGen.ArrayVar e _) = e
        hasExp (ImpGen.ScalarVar e _) = e
        hasExp (ImpGen.MemVar e _) = e

-- | Change every memory block to be in the global address space,
-- except those that are in the local memory space.  This only affects
-- generated code; we still need to make sure that the memory is
-- actually present on the device (and declared as variables in the
-- kernel).
makeAllMemoryGlobal :: CallKernelGen a
                    -> CallKernelGen a
makeAllMemoryGlobal =
  local $ \env -> env { ImpGen.envVtable = M.map globalMemory $ ImpGen.envVtable env
                      , ImpGen.envDefaultSpace = Imp.Space "global"
                      }
  where globalMemory (ImpGen.MemVar _ entry)
          | ImpGen.entryMemSpace entry /= Space "local" =
              ImpGen.MemVar Nothing entry { ImpGen.entryMemSpace = Imp.Space "global" }
        globalMemory entry =
          entry

computeMapKernelGroups :: Imp.Exp -> CallKernelGen (VName, VName)
computeMapKernelGroups kernel_size = do
  group_size <- newVName "group_size"
  num_groups <- newVName "num_groups"
  let group_size_var = Imp.var group_size int32
  ImpGen.emit $ Imp.DeclareScalar group_size int32
  ImpGen.emit $ Imp.DeclareScalar num_groups int32
  ImpGen.emit $ Imp.Op $ Imp.GetSize group_size group_size Imp.SizeGroup
  ImpGen.emit $ Imp.SetScalar num_groups $
    kernel_size `quotRoundingUp` Imp.ConvOpExp (SExt Int32 Int32) group_size_var
  return (group_size, num_groups)
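
-- A rough illustration of the arithmetic above (not normative): with a
-- kernel size of 1000 and a group size of 256, 'quotRoundingUp' yields
-- four groups, i.e. 1024 threads, of which the last 24 are idle.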

isMapTransposeKernel :: PrimType -> ImpGen.MemLocation -> ImpGen.MemLocation
                     -> Maybe (Imp.Exp, Imp.Exp,
                               Imp.Exp, Imp.Exp, Imp.Exp,
                               Imp.Exp, Imp.Exp)
isMapTransposeKernel bt
  (ImpGen.MemLocation _ _ destIxFun)
  (ImpGen.MemLocation _ _ srcIxFun)
  | Just (dest_offset, perm_and_destshape) <- IxFun.rearrangeWithOffset destIxFun bt_size,
    (perm, destshape) <- unzip perm_and_destshape,
    srcshape' <- IxFun.shape srcIxFun,
    Just src_offset <- IxFun.linearWithOffset srcIxFun bt_size,
    Just (r1, r2, _) <- isMapTranspose perm =
    isOk (product srcshape') (product destshape) destshape swap r1 r2 dest_offset src_offset
  | Just dest_offset <- IxFun.linearWithOffset destIxFun bt_size,
    Just (src_offset, perm_and_srcshape) <- IxFun.rearrangeWithOffset srcIxFun bt_size,
    (perm, srcshape) <- unzip perm_and_srcshape,
    destshape' <- IxFun.shape destIxFun,
    Just (r1, r2, _) <- isMapTranspose perm =
    isOk (product srcshape) (product destshape') srcshape id r1 r2 dest_offset src_offset
  | otherwise =
    Nothing
  where bt_size = primByteSize bt
        swap (x,y) = (y,x)

        isOk src_elems dest_elems shape f r1 r2 dest_offset src_offset = do
          let (num_arrays, size_x, size_y) = getSizes shape f r1 r2
          return (dest_offset, src_offset,
                  num_arrays, size_x, size_y,
                  src_elems, dest_elems)

        getSizes shape f r1 r2 =
          let (mapped, notmapped) = splitAt r1 shape
              (pretrans, posttrans) = f $ splitAt r2 notmapped
          in (product mapped, product pretrans, product posttrans)
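
-- A sketch of the 'getSizes' helper above (illustrative only): for a
-- shape [2, 8, 16] with r1 = 1, r2 = 1 and f = id, the result is
-- (2, 8, 16), i.e. two arrays of 8x16 elements each; with f = swap the
-- inner sizes are flipped, giving (2, 16, 8).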

writeParamToLocalMemory :: Typed (MemBound u) =>
                           Imp.Exp -> (VName, t) -> Param (MemBound u)
                        -> ImpGen.ImpM lore op ()
writeParamToLocalMemory i (mem, _) param
  | Prim t <- paramType param =
      ImpGen.emit $
      Imp.Write mem (bytes i') bt (Space "local") Imp.Volatile $
      Imp.var (paramName param) t
  | otherwise =
      return ()
  where i' = i * Imp.LeafExp (Imp.SizeOf bt) int32
        bt = elemType $ paramType param

readParamFromLocalMemory :: Typed (MemBound u) =>
                            VName -> Imp.Exp -> Param (MemBound u) -> (VName, t)
                         -> ImpGen.ImpM lore op ()
readParamFromLocalMemory index i param (l_mem, _)
  | Prim _ <- paramType param =
      ImpGen.emit $
      Imp.SetScalar (paramName param) $
      Imp.index l_mem (bytes i') bt (Space "local") Imp.Volatile
  | otherwise =
      ImpGen.emit $
      Imp.SetScalar index i
  where i' = i * Imp.LeafExp (Imp.SizeOf bt) int32
        bt = elemType $ paramType param
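
-- A note on the byte offsets in the two local-memory helpers above
-- (illustrative only): for an 'int32' parameter, the element at local
-- index i lives at byte offset 4 * i, since 'Imp.SizeOf' of 'int32' is 4.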

computeThreadChunkSize :: SplitOrdering
                       -> Imp.Exp
                       -> Imp.Count Imp.Elements
                       -> Imp.Count Imp.Elements
                       -> VName
                       -> ImpGen.ImpM lore op ()
computeThreadChunkSize (SplitStrided stride) thread_index elements_per_thread num_elements chunk_var = do
  stride' <- ImpGen.compileSubExp stride
  ImpGen.emit $ Imp.SetScalar chunk_var $ Imp.BinOpExp (SMin Int32)
    (Imp.innerExp elements_per_thread) $
    (Imp.innerExp num_elements - thread_index)
    `quotRoundingUp`
    stride'
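
-- A worked example of the strided case above (illustrative only): with
-- 100 elements, a stride of 32, 8 elements per thread and thread index 7,
-- the chunk is min 8 ((100 - 7) `quotRoundingUp` 32) = min 8 3 = 3.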

computeThreadChunkSize SplitContiguous thread_index elements_per_thread num_elements chunk_var = do
  starting_point <- newVName "starting_point"
  remaining_elements <- newVName "remaining_elements"

  ImpGen.emit $
    Imp.DeclareScalar starting_point int32
  ImpGen.emit $
    Imp.SetScalar starting_point $
    thread_index * Imp.innerExp elements_per_thread

  ImpGen.emit $
    Imp.DeclareScalar remaining_elements int32
  ImpGen.emit $
    Imp.SetScalar remaining_elements $
    Imp.innerExp num_elements - Imp.var starting_point int32

  let no_remaining_elements = Imp.CmpOpExp (CmpSle Int32)
                              (Imp.var remaining_elements int32) 0
      beyond_bounds = Imp.CmpOpExp (CmpSle Int32)
                      (Imp.innerExp num_elements)
                      (Imp.var starting_point int32)

  ImpGen.emit $
    Imp.If (Imp.BinOpExp LogOr no_remaining_elements beyond_bounds)
    (Imp.SetScalar chunk_var 0)
    (Imp.If is_last_thread
     (Imp.SetScalar chunk_var $ Imp.innerExp last_thread_elements)
     (Imp.SetScalar chunk_var $ Imp.innerExp elements_per_thread))
  where last_thread_elements =
          num_elements - Imp.elements thread_index * elements_per_thread
        is_last_thread =
          Imp.CmpOpExp (CmpSlt Int32)
          (Imp.innerExp num_elements)
          ((thread_index + 1) * Imp.innerExp elements_per_thread)
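
-- A worked example of the contiguous case above (illustrative only):
-- with 100 elements and 8 elements per thread, threads 0 through 11 get
-- a chunk of 8, thread 12 is the last active thread and gets the
-- remaining 4 elements, and threads 13 and up get a chunk of 0.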

inBlockScan :: Imp.Exp
           -> Imp.Exp
           -> Imp.Exp
           -> VName
           -> [(VName, t)]
           -> Lambda InKernel
           -> InKernelGen ()
inBlockScan lockstep_width block_size active local_id acc_local_mem scan_lam = ImpGen.everythingVolatile $ do
  skip_threads <- newVName "skip_threads"
  let in_block_thread_active =
        Imp.CmpOpExp (CmpSle Int32) (Imp.var skip_threads int32) in_block_id
      (scan_lam_i, other_index_param, actual_params) =
        partitionChunkedKernelLambdaParameters $ lambdaParams scan_lam
      (x_params, y_params) =
        splitAt (length actual_params `div` 2) actual_params
  read_operands <-
    ImpGen.collect $
    zipWithM_ (readParamFromLocalMemory (paramName other_index_param) $
               Imp.var local_id int32 - Imp.var skip_threads int32)
    x_params acc_local_mem
  scan_y_dest <- ImpGen.destinationFromParams y_params

  -- Set initial y values
  read_my_initial <- ImpGen.collect $
                     zipWithM_ (readParamFromLocalMemory scan_lam_i $ Imp.var local_id int32)
                     y_params acc_local_mem
  ImpGen.emit $ Imp.If active read_my_initial mempty

  op_to_y <- ImpGen.collect $ ImpGen.compileBody scan_y_dest $ lambdaBody scan_lam
  write_operation_result <-
    ImpGen.collect $
    zipWithM_ (writeParamToLocalMemory $ Imp.var local_id int32)
    acc_local_mem y_params
  let andBlockActive = Imp.BinOpExp LogAnd active
      maybeBarrier = Imp.If (Imp.CmpOpExp (CmpSle Int32) lockstep_width (Imp.var skip_threads int32))
                     (Imp.Op Imp.Barrier) mempty

  ImpGen.emit $
    Imp.Comment "in-block scan (hopefully no barriers needed)" $
    Imp.DeclareScalar skip_threads int32 <>
    Imp.SetScalar skip_threads 1 <>
    Imp.While (Imp.CmpOpExp (CmpSlt Int32) (Imp.var skip_threads int32) block_size)
    (Imp.If (andBlockActive in_block_thread_active)
      (Imp.Comment "read operands" read_operands <>
       Imp.Comment "perform operation" op_to_y) mempty <>

     maybeBarrier <>

     Imp.If (andBlockActive in_block_thread_active)
      (Imp.Comment "write result" write_operation_result) mempty <>
     maybeBarrier <>
     Imp.SetScalar skip_threads (Imp.var skip_threads int32 * 2))
  where block_id = Imp.BinOpExp (SQuot Int32) (Imp.var local_id int32) block_size
        in_block_id = Imp.var local_id int32 - block_id * block_size
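
-- A sketch of the block bookkeeping in 'inBlockScan' above (illustrative
-- only): with a block size of 32, the thread with local id 37 has
-- block_id = 1 and in_block_id = 5, so it participates in an iteration
-- only while skip_threads <= 5.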

data KernelConstants = KernelConstants
                       { kernelGlobalThreadId :: VName
                       , kernelLocalThreadId :: VName
                       , kernelGroupId :: VName
                       , kernelGroupSize :: Imp.DimSize
                       , _kernelNumThreads :: Imp.DimSize
                       , kernelWaveSize :: Imp.DimSize
                       , kernelDimensions :: [(VName, Imp.Exp)]
                       , kernelThreadActive :: Imp.Exp
                       , kernelStreamed :: [(VName, Imp.DimSize)]
                       -- ^ Chunk sizes and their maximum size.  Hint
                       -- for unrolling.
                       }

-- FIXME: wing a KernelConstants structure for use in Replicate
-- compilation.  This cannot be the best way to do this...
simpleKernelConstants :: MonadFreshNames m =>
                         Maybe Int -> String
                      -> m KernelConstants
simpleKernelConstants tag desc = do
  thread_gtid <- maybe (newVName $ desc ++ "_gtid")
                       (return . VName (nameFromString $ desc ++ "_gtid")) tag
  thread_ltid <- newVName $ desc ++ "_ltid"
  thread_gid <- newVName $ desc ++ "_gid"
  return $ KernelConstants
    thread_gtid thread_ltid thread_gid
    (Imp.ConstSize 0) (Imp.ConstSize 0) (Imp.ConstSize 0)
    [] (Imp.ValueExp $ BoolValue True) mempty

compileKernelBody :: ImpGen.Destination
                  -> KernelConstants
                  -> KernelBody InKernel
                  -> InKernelGen ()
compileKernelBody (ImpGen.Destination _ dest) constants kbody =
  compileKernelStms constants (stmsToList $ kernelBodyStms kbody) $
  zipWithM_ (compileKernelResult constants) dest $
  kernelBodyResult kbody

compileNestedKernelBody :: KernelConstants
                        -> ImpGen.Destination
                        -> Body InKernel
                        -> InKernelGen ()
compileNestedKernelBody constants (ImpGen.Destination _ dest) kbody =
  compileKernelStms constants (stmsToList $ bodyStms kbody) $
  zipWithM_ ImpGen.compileSubExpTo dest $ bodyResult kbody

compileKernelStms :: KernelConstants -> [Stm InKernel]
                  -> InKernelGen a
                  -> InKernelGen a
compileKernelStms constants ungrouped_bnds m =
  compileGroupedKernelStms' $ groupStmsByGuard constants ungrouped_bnds
  where compileGroupedKernelStms' [] = m
        compileGroupedKernelStms' ((g, bnds):rest_bnds) =
          ImpGen.declaringScopes
          (map ((Just . stmExp) &&& (castScope . scopeOf)) bnds) $ do
            protect g $ mapM_ compileKernelStm bnds
            compileGroupedKernelStms' rest_bnds

        protect Nothing body_m =
          body_m
        protect (Just (Imp.ValueExp (BoolValue True))) body_m =
          body_m
        protect (Just g) body_m = do
          body <- allThreads constants body_m
          ImpGen.emit $ Imp.If g body mempty

        compileKernelStm (Let pat _ e) = do
          dest <- ImpGen.destinationFromPattern pat
          ImpGen.compileExp dest e

groupStmsByGuard :: KernelConstants
                     -> [Stm InKernel]
                     -> [(Maybe Imp.Exp, [Stm InKernel])]
groupStmsByGuard constants bnds =
  map collapse $ groupBy sameGuard $ zip (map bindingGuard bnds) bnds
  where bindingGuard (Let _ _ Op{}) = Nothing
        bindingGuard _ = Just $ kernelThreadActive constants

        sameGuard (g1, _) (g2, _) = g1 == g2

        collapse [] =
          (Nothing, [])
        collapse l@((g,_):_) =
          (g, map snd l)

compileKernelExp :: KernelConstants -> ImpGen.Destination -> KernelExp InKernel
                 -> InKernelGen ()

compileKernelExp _ (ImpGen.Destination _ dests) (Barrier ses) = do
  zipWithM_ ImpGen.compileSubExpTo dests ses
  ImpGen.emit $ Imp.Op Imp.Barrier

compileKernelExp _ dest (SplitSpace o w i elems_per_thread)
  | ImpGen.Destination _ [ImpGen.ScalarDestination size] <- dest = do
      num_elements <- Imp.elements <$> ImpGen.compileSubExp w
      i' <- ImpGen.compileSubExp i
      elems_per_thread' <- Imp.elements <$> ImpGen.compileSubExp elems_per_thread
      computeThreadChunkSize o i' elems_per_thread' num_elements size

compileKernelExp constants dest (Combine (CombineSpace scatter cspace) ts aspace body) = do
  -- First we compute how many times we have to iterate to cover
  -- cspace with our group size.  It is a fairly common case that
  -- we statically know that this requires 1 iteration, so we
  -- could detect it and not generate a loop in that case.
  -- However, it seems to have no impact on performance (an extra
  -- conditional jump), so for simplicity we just always generate
  -- the loop.
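  --
  -- As a rough illustration: for a combine space of 30 x 30 = 900
  -- elements and a group size of 256, num_iters is 4, so each thread
  -- visits up to four flat indices of the combine space.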
  let cspace_dims = map (streamBounded . snd) cspace
      num_iters = product cspace_dims `quotRoundingUp`
                  Imp.sizeToExp (kernelGroupSize constants)

  iter <- newVName "comb_iter"
  cid <- newVName "flat_comb_id"

  one_iteration <- ImpGen.collect $
    ImpGen.declaringPrimVars (zip (map fst cspace) $ repeat int32) $
    ImpGen.declaringPrimVar cid int32 $ do

      -- Compute the *flat* array index.
      ImpGen.emit $ Imp.SetScalar cid $
        Imp.var iter int32 * Imp.sizeToExp (kernelGroupSize constants) +
        Imp.var (kernelLocalThreadId constants) int32

      -- Turn it into a nested array index.
      forM_ (zip (map fst cspace) $ unflattenIndex cspace_dims (Imp.var cid int32)) $ \(v, x) ->
        ImpGen.emit $ Imp.SetScalar v x

      -- Construct the body.  This is mostly about the book-keeping
      -- for the scatter-like part.
      let (scatter_ws, scatter_ns, _scatter_vs) = unzip3 scatter
          scatter_ws_repl = concat $ zipWith replicate scatter_ns scatter_ws
          (scatter_dests, normal_dests) =
            splitAt (sum scatter_ns) $ ImpGen.valueDestinations dest
          (res_is, res_vs, res_normal) =
            splitAt3 (sum scatter_ns) (sum scatter_ns) $ bodyResult body
          scatter_is = map (pure . DimFix . ImpGen.compileSubExpOfType int32) res_is
          scatter_dests_repl = concat $ zipWith replicate scatter_ns scatter_dests
      (scatter_dests', normal_dests') <-
        case (sequence $ zipWith3 index scatter_is ts scatter_dests_repl,
              zipWithM (index local_index) (drop (sum scatter_ns*2) ts) normal_dests) of
          (Just x, Just y) -> return (x, y)
          _ -> fail "compileKernelExp combine: invalid destination."
      body' <- allThreads constants $
        ImpGen.compileStms (freeIn $ bodyResult body) (stmsToList $ bodyStms body) $ do

        forM_ (zip4 scatter_ws_repl res_is res_vs scatter_dests') $
          \(w, res_i, res_v, scatter_dest) -> do
            let res_i' = ImpGen.compileSubExpOfType int32 res_i
                w'     = ImpGen.compileSubExpOfType int32 w
                -- We have to check that 'res_i' is in-bounds wrt. an array of size 'w'.
                in_bounds = BinOpExp LogAnd (CmpOpExp (CmpSle Int32) 0 res_i')
                                            (CmpOpExp (CmpSlt Int32) res_i' w')
            when_in_bounds <- ImpGen.collect $ ImpGen.compileSubExpTo scatter_dest res_v
            ImpGen.emit $ Imp.If in_bounds when_in_bounds mempty

        zipWithM_ ImpGen.compileSubExpTo normal_dests' res_normal

      -- Execute the body if we are within bounds.
      ImpGen.emit $
        Imp.If (Imp.BinOpExp LogAnd (isActive cspace) (isActive aspace)) body' mempty

  ImpGen.emit $ Imp.For iter Int32 num_iters one_iteration
  ImpGen.emit $ Imp.Op Imp.Barrier

    where streamBounded (Var v)
            | Just x <- lookup v $ kernelStreamed constants =
                Imp.sizeToExp x
          streamBounded se = ImpGen.compileSubExpOfType int32 se

          local_index = map (DimFix . ImpGen.varIndex . fst) cspace

          index i t (ImpGen.ArrayDestination (Just loc)) =
            let space_dims = map (ImpGen.varIndex . fst) cspace
                t_dims = map (ImpGen.compileSubExpOfType int32) $ arrayDims t
            in Just $ ImpGen.ArrayDestination $
               Just $ ImpGen.sliceArray loc $
               fullSliceNum (space_dims++t_dims) i
          index _ _ _ = Nothing

compileKernelExp constants (ImpGen.Destination _ dests) (GroupReduce w lam input) = do
  skip_waves <- newVName "skip_waves"
  w' <- ImpGen.compileSubExp w

  let local_tid = kernelLocalThreadId constants
      (_nes, arrs) = unzip input
      (reduce_i, reduce_j_param, actual_reduce_params) =
        partitionChunkedKernelLambdaParameters $ lambdaParams lam
      (reduce_acc_params, reduce_arr_params) =
        splitAt (length input) actual_reduce_params
      reduce_j = paramName reduce_j_param

  offset <- newVName "offset"
  ImpGen.emit $ Imp.DeclareScalar offset int32

  ImpGen.Destination _ reduce_acc_targets <-
    ImpGen.destinationFromParams reduce_acc_params

  ImpGen.declaringPrimVar skip_waves int32 $
    ImpGen.declaringLParams (lambdaParams lam) $ do

    ImpGen.emit $ Imp.SetScalar reduce_i $ Imp.var local_tid int32

    let setOffset x =
          Imp.SetScalar offset x <>
          Imp.SetScalar reduce_j (Imp.var local_tid int32 + Imp.var offset int32)
    ImpGen.emit $ setOffset 0

    set_init_params <- ImpGen.collect $
      zipWithM_ (readReduceArgument offset) reduce_acc_params arrs
    ImpGen.emit $
      Imp.If (Imp.CmpOpExp (CmpSlt Int32) (Imp.var local_tid int32) w')
      set_init_params mempty

    let read_reduce_args = zipWithM_ (readReduceArgument offset)
                           reduce_arr_params arrs
        reduce_acc_dest = ImpGen.Destination Nothing reduce_acc_targets
        do_reduce = do ImpGen.comment "read array element" read_reduce_args
                       ImpGen.compileBody reduce_acc_dest $ lambdaBody lam
                       zipWithM_ (writeReduceOpResult local_tid)
                         reduce_acc_params arrs

    in_wave_reduce <- ImpGen.collect $ ImpGen.everythingVolatile do_reduce
    cross_wave_reduce <- ImpGen.collect do_reduce

    let wave_size = Imp.sizeToExp $ kernelWaveSize constants
        group_size = Imp.sizeToExp $ kernelGroupSize constants
        wave_id = Imp.var local_tid int32 `quot` wave_size
        in_wave_id = Imp.var local_tid int32 - wave_id * wave_size
        num_waves = (group_size + wave_size - 1) `quot` wave_size
        arg_in_bounds = Imp.CmpOpExp (CmpSlt Int32)
                        (Imp.var reduce_j int32) w'

        doing_in_wave_reductions =
          Imp.CmpOpExp (CmpSlt Int32) (Imp.var offset int32) wave_size
        apply_in_in_wave_iteration =
          Imp.CmpOpExp (CmpEq int32)
          (Imp.BinOpExp (And Int32) in_wave_id (2 * Imp.var offset int32 - 1)) 0
        in_wave_reductions =
          setOffset 1 <>
          Imp.While doing_in_wave_reductions
            (Imp.If (Imp.BinOpExp LogAnd arg_in_bounds apply_in_in_wave_iteration)
             in_wave_reduce mempty <>
             setOffset (Imp.var offset int32 * 2))

        doing_cross_wave_reductions =
          Imp.CmpOpExp (CmpSlt Int32) (Imp.var skip_waves int32) num_waves
        is_first_thread_in_wave =
          Imp.CmpOpExp (CmpEq int32) in_wave_id 0
        wave_not_skipped =
          Imp.CmpOpExp (CmpEq int32)
          (Imp.BinOpExp (And Int32) wave_id (2 * Imp.var skip_waves int32 - 1))
          0
        apply_in_cross_wave_iteration =
          Imp.BinOpExp LogAnd arg_in_bounds $
          Imp.BinOpExp LogAnd is_first_thread_in_wave wave_not_skipped
        cross_wave_reductions =
          Imp.SetScalar skip_waves 1 <>
          Imp.While doing_cross_wave_reductions
            (Imp.Op Imp.Barrier <>
             setOffset (Imp.var skip_waves int32 * wave_size) <>
             Imp.If apply_in_cross_wave_iteration
             cross_wave_reduce mempty <>
             Imp.SetScalar skip_waves (Imp.var skip_waves int32 * 2))

    ImpGen.emit $
      in_wave_reductions <> cross_wave_reductions

    forM_ (zip dests reduce_acc_params) $ \(dest, reduce_acc_param) ->
      ImpGen.copyDWIMDest dest [] (Var $ paramName reduce_acc_param) []
  where readReduceArgument offset param arr
          | Prim _ <- paramType param =
              ImpGen.copyDWIM (paramName param) [] (Var arr) [i]
          | otherwise =
              return ()
          where i = ImpGen.varIndex (kernelLocalThreadId constants) + ImpGen.varIndex offset

        writeReduceOpResult i param arr
          | Prim _ <- paramType param =
              ImpGen.copyDWIM arr [ImpGen.varIndex i] (Var $ paramName param) []
          | otherwise =
              return ()

compileKernelExp constants _ (GroupScan w lam input) = do
  renamed_lam <- renameLambda lam
  w' <- ImpGen.compileSubExp w

  when (any (not . primType . paramType) $ lambdaParams lam) $
    compilerLimitationS "Cannot compile parallel scans with array element type."

  let local_tid = kernelLocalThreadId constants
      (_nes, arrs) = unzip input
      (lam_i, other_index_param, actual_params) =
        partitionChunkedKernelLambdaParameters $ lambdaParams lam
      (x_params, y_params) =
        splitAt (length input) actual_params

  ImpGen.declaringLParams (lambdaParams lam++lambdaParams renamed_lam) $ do
    ImpGen.emit $ Imp.SetScalar lam_i $ Imp.var local_tid int32

    acc_local_mem <- flip zip (repeat ()) <$>
                     mapM (fmap (ImpGen.memLocationName . ImpGen.entryArrayLocation) .
                           ImpGen.lookupArray) arrs

    -- The scan works by splitting the group into blocks, which are
    -- scanned separately.  Typically, these blocks are smaller than
    -- the lockstep width, which enables barrier-free execution inside
    -- them.
    --
    -- We hardcode the block size here.  The only requirement is that
    -- it should not be less than the square root of the group size.
    -- With 32, we will work on groups of size 1024 or smaller, which
    -- fits every device Troels has seen.  Still, it would be nicer if
    -- it were a runtime parameter.  Some day.
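    --
    -- The square-root requirement comes from the second pass below: the
    -- per-block carries (group_size / block_size of them) must themselves
    -- be scanned by a single block, so we need group_size / block_size <=
    -- block_size, i.e. block_size^2 >= group_size; 32 * 32 = 1024.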
    let block_size = Imp.ValueExp $ IntValue $ Int32Value 32
        simd_width = Imp.sizeToExp $ kernelWaveSize constants
        block_id = Imp.var local_tid int32 `quot` block_size
        in_block_id = Imp.var local_tid int32 - block_id * block_size
        doInBlockScan active = inBlockScan simd_width block_size active local_tid acc_local_mem
        lid_in_bounds = Imp.CmpOpExp (CmpSlt Int32) (Imp.var local_tid int32) w'

    doInBlockScan lid_in_bounds lam
    ImpGen.emit $ Imp.Op Imp.Barrier

    pack_block_results <-
      ImpGen.collect $
      zipWithM_ (writeParamToLocalMemory block_id) acc_local_mem y_params

    let last_in_block =
          Imp.CmpOpExp (CmpEq int32) in_block_id $ block_size - 1
    ImpGen.comment
      "last thread of block 'i' writes its result to offset 'i'" $
      ImpGen.emit $ Imp.If (Imp.BinOpExp LogAnd last_in_block lid_in_bounds) pack_block_results mempty

    ImpGen.emit $ Imp.Op Imp.Barrier

    let is_first_block = Imp.CmpOpExp (CmpEq int32) block_id 0
    ImpGen.comment
      "scan the first block, after which offset 'i' contains carry-in for warp 'i+1'" $
      doInBlockScan (Imp.BinOpExp LogAnd is_first_block lid_in_bounds) renamed_lam

    ImpGen.emit $ Imp.Op Imp.Barrier

    read_carry_in <-
      ImpGen.collect $
      zipWithM_ (readParamFromLocalMemory
                 (paramName other_index_param) (block_id - 1))
      x_params acc_local_mem

    y_dest <- ImpGen.destinationFromParams y_params
    op_to_y <- ImpGen.collect $ ImpGen.compileBody y_dest $ lambdaBody lam
    write_final_result <- ImpGen.collect $
      zipWithM_ (writeParamToLocalMemory $ Imp.var local_tid int32) acc_local_mem y_params

    ImpGen.comment "carry-in for every block except the first" $
      ImpGen.emit $ Imp.If (Imp.BinOpExp LogOr
                             is_first_block
                             (Imp.UnOpExp Not lid_in_bounds)) mempty $
      Imp.Comment "read operands" read_carry_in <>
      Imp.Comment "perform operation" op_to_y <>
      Imp.Comment "write final result" write_final_result

    ImpGen.emit $ Imp.Op Imp.Barrier

    ImpGen.comment "restore correct values for first block" $
      ImpGen.emit $ Imp.If is_first_block write_final_result mempty


compileKernelExp constants (ImpGen.Destination _ final_targets) (GroupStream w maxchunk lam accs _arrs) = do
  let GroupStreamLambda block_size block_offset acc_params arr_params body = lam
      block_offset' = Imp.var block_offset int32
  w' <- ImpGen.compileSubExp w
  max_block_size <- ImpGen.compileSubExp maxchunk
  acc_dest <- ImpGen.destinationFromParams acc_params

  ImpGen.declaringLParams (acc_params++arr_params) $ do
    zipWithM_ ImpGen.compileSubExpTo (ImpGen.valueDestinations acc_dest) accs
    ImpGen.declaringPrimVar block_size int32 $
      -- If the GroupStream is morally just a do-loop, generate simpler code.
      case mapM isSimpleThreadInSpace $ stmsToList $ bodyStms body of
        Just stms' | ValueExp x <- max_block_size, oneIsh x -> do
          let body' = body { bodyStms = stmsFromList stms' }
          body'' <- ImpGen.withPrimVar block_offset int32 $
                    allThreads constants $ ImpGen.emit =<<
                    ImpGen.compileLoopBody (map paramName acc_params) body'
          ImpGen.emit $ Imp.SetScalar block_size 1

          -- Check if the loop is a candidate for unrolling.
          let loop =
                case w of
                  Var w_var | Just w_bound <- lookup w_var $ kernelStreamed constants,
                              w_bound /= Imp.ConstSize 1 ->
                              -- Candidate for unrolling, so generate two loops.
                              Imp.If (CmpOpExp (CmpEq int32) w' (Imp.sizeToExp w_bound))
                              (Imp.For block_offset Int32 (Imp.sizeToExp w_bound) body'')
                              (Imp.For block_offset Int32 w' body'')
                  _ -> Imp.For block_offset Int32 w' body''

          ImpGen.emit $
            if kernelThreadActive constants == Imp.ValueExp (BoolValue True)
            then loop
            else Imp.If (kernelThreadActive constants) loop mempty

        _ -> ImpGen.declaringPrimVar block_offset int32 $ do
          body' <- streaming constants block_size maxchunk $
                   ImpGen.compileBody acc_dest body

          ImpGen.emit $ Imp.SetScalar block_offset 0

          let not_at_end =
                Imp.CmpOpExp (CmpSlt Int32) block_offset' w'
              set_block_size =
                Imp.If (Imp.CmpOpExp (CmpSlt Int32)
                         (w' - block_offset')
                         max_block_size)
                (Imp.SetScalar block_size (w' - block_offset'))
                (Imp.SetScalar block_size max_block_size)
              increase_offset =
                Imp.SetScalar block_offset $
                block_offset' + max_block_size

          -- Three cases to consider for simpler generated code based
          -- on the max block size: (0) if it equals the full input
          -- size, do not generate a loop; (1) if it is one, generate a
          -- for-loop; (2) otherwise, generate a chunked while-loop.
          ImpGen.emit $
            if max_block_size == w' then
              Imp.SetScalar block_size w' <> body'
            else if max_block_size == Imp.ValueExp (value (1::Int32)) then
                   Imp.SetScalar block_size w' <>
                   Imp.For block_offset Int32 w' body'
                 else
                   Imp.While not_at_end $
                   set_block_size <> body' <> increase_offset

    zipWithM_ ImpGen.compileSubExpTo final_targets $
      map (Var . paramName) acc_params

      where isSimpleThreadInSpace (Let _ _ Op{}) = Nothing
            isSimpleThreadInSpace bnd = Just bnd

compileKernelExp _ _ (GroupGenReduce w [a] op bucket [v] _)
  | [Prim t] <- lambdaReturnType op,
    primBitSize t == 32 = do
  -- If we have only one array and one non-array value (this is a
  -- one-to-one correspondence), then we need only one update.  If the
  -- operator has an atomic implementation we use that; otherwise it is
  -- still a binary operator, which can be implemented with an atomic
  -- compare-and-swap as long as the element is 32 bits wide.

  -- Common variables.
  old <- newVName "old"
  old_bits <- newVName "old_bits"
  ImpGen.emit $ Imp.DeclareScalar old t
  ImpGen.emit $ Imp.DeclareScalar old_bits int32
  bucket' <- mapM ImpGen.compileSubExp bucket
  w' <- mapM ImpGen.compileSubExp w

  (arr', _a_space, bucket_offset) <- ImpGen.fullyIndexArray a bucket'

  case opHasAtomicSupport old arr' bucket_offset op of
    Just f -> do
      val' <- ImpGen.compileSubExp v

      ImpGen.emit $
        Imp.If (indexInBounds bucket' w')
        (Imp.Op $ f val')
        Imp.Skip

    Nothing -> do
      -- Code generation target:
      --
      -- old = d_his[idx];
      -- do {
      --   assumed = old;
      --   tmp = OP::apply(val, assumed);
      --   old = atomicCAS(&d_his[idx], assumed, tmp);
      -- } while(assumed != old);
      assumed <- newVName "assumed"
      run_loop <- newVName "run_loop"
      ImpGen.emit $ Imp.DeclareScalar assumed t
      ImpGen.emit $ Imp.DeclareScalar run_loop int32

      read_old <- ImpGen.collect $
        ImpGen.copyDWIMDest (ImpGen.ScalarDestination old) [] (Var a) bucket'

      ImpGen.emit $
        Imp.If (indexInBounds bucket' w')
        -- True branch: bucket in-bounds -> enter loop
        (Imp.SetScalar run_loop 1 <> read_old)
        -- False branch: bucket out-of-bounds -> skip loop
        (Imp.SetScalar run_loop 0)

      -- Preparing parameters
      let (acc_p:arr_p:_) = lambdaParams op

      -- Store result from operator in accumulators
      dests <- ImpGen.destinationFromParams [acc_p]

      -- Critical section
      ImpGen.declaringLParams (lambdaParams op) $ do
        bind_acc_param <- ImpGen.collect $
          ImpGen.copyDWIMDest (ImpGen.ScalarDestination $ paramName acc_p) [] v []

        let bind_arr_param =
              Imp.SetScalar (paramName arr_p) $ Imp.var assumed t

        op_body <- ImpGen.collect $
          ImpGen.compileBody dests $ lambdaBody op

        -- While-loop: Try to insert your value
        let (toBits, fromBits) =
              case t of FloatType Float32 -> (\x -> Imp.FunExp "to_bits32" [x] int32,
                                              \x -> Imp.FunExp "from_bits32" [x] t)
                        _                 -> (id, id)
        ImpGen.emit $ Imp.While (Imp.var run_loop int32)
          (Imp.SetScalar assumed (Imp.var old t) <>
           bind_acc_param <> bind_arr_param <> op_body
           <>
           (Imp.Op $
               Imp.Atomic $
                 Imp.AtomicCmpXchg old_bits arr' bucket_offset
                   (toBits (Imp.var assumed int32)) (toBits (Imp.var (paramName acc_p) int32)))
           <>
           Imp.SetScalar old (fromBits (Imp.var old_bits int32))
           <>
            Imp.If
              (Imp.CmpOpExp
                (CmpEq int32) (toBits $ Imp.var assumed t) (Imp.var old_bits int32))
              -- True branch:
              (Imp.SetScalar run_loop 0)
              -- False branch:
              Imp.Skip
          )

    where opHasAtomicSupport old arr' bucket' lam = do
            let atomic f = Imp.Atomic . f old arr' bucket'
                atomics = [ (Add Int32, Imp.AtomicAdd)
                          , (SMax Int32, Imp.AtomicSMax)
                          , (SMin Int32, Imp.AtomicSMin)
                          , (UMax Int32, Imp.AtomicUMax)
                          , (UMin Int32, Imp.AtomicUMin)
                          , (And Int32, Imp.AtomicAnd)
                          , (Or Int32, Imp.AtomicOr)
                          , (Xor Int32, Imp.AtomicXor)
                          ]
            [BasicOp (BinOp bop _ _)] <-
              Just $ map stmExp $ stmsToList $ bodyStms $ lambdaBody lam
            atomic <$> lookup bop atomics

compileKernelExp _ _ (GroupGenReduce w arrs op bucket values locks) = do
  old <- newVName "old"
  tmp <- newVName "tmp"
  loop_done <- newVName "loop_done"
  ImpGen.emit $
    Imp.DeclareScalar old int32 <>
    Imp.DeclareScalar tmp int32 <>
    Imp.DeclareScalar loop_done int32

  -- Check if bucket is in-bounds
  bucket' <- mapM ImpGen.compileSubExp bucket
  w' <- mapM ImpGen.compileSubExp w

  -- Correctly index into locks.
  (locks', _locks_space, locks_offset) <-
    ImpGen.fullyIndexArray locks bucket'

  ImpGen.emit $
    Imp.If (indexInBounds bucket' w')
    -- True branch: bucket in-bounds -> enter loop
    (Imp.SetScalar loop_done 0)
    -- False branch: bucket out-of-bounds -> skip loop
    (Imp.SetScalar loop_done 1)

  -- Preparing parameters
  let (acc_params, arr_params) =
        splitAt (length values) $ lambdaParams op

  -- Store result from operator in accumulators
  dests <- ImpGen.destinationFromParams acc_params

  -- Critical section
  ImpGen.declaringLParams (lambdaParams op) $ do
    let try_acquire_lock =
          Imp.Op $ Imp.Atomic $
          Imp.AtomicXchg old locks' locks_offset 1
        lock_acquired =
          Imp.CmpOpExp (CmpEq int32) (Imp.var old int32) 0
        loop_cond =
          Imp.CmpOpExp (CmpEq int32) (Imp.var loop_done int32) 0
        break_loop =
          Imp.SetScalar loop_done 1

    -- We copy the current value and the new value to the parameters
    -- unless they are array-typed.  If they are arrays, then the
    -- index functions should already be set up correctly, so there is
    -- nothing more to do.
    bind_acc_params <- ImpGen.collect $
      forM_ (zip acc_params arrs) $ \(acc_p, arr) ->
      when (primType (paramType acc_p)) $
      ImpGen.copyDWIMDest (ImpGen.ScalarDestination $ paramName acc_p) [] (Var arr) bucket'

    bind_arr_params <- ImpGen.collect $
      forM_ (zip arr_params values) $ \(arr_p, val) ->
      when (primType (paramType arr_p)) $
      ImpGen.copyDWIMDest (ImpGen.ScalarDestination $ paramName arr_p) [] val []

    op_body <- ImpGen.collect $
      ImpGen.compileBody dests $ lambdaBody op

    do_gen_reduce <- ImpGen.collect $
      zipWithM_ (writeArray bucket') arrs $ map (Var . paramName) acc_params

    release_lock <- ImpGen.collect $
      ImpGen.copyDWIM locks bucket' (intConst Int32 0) []

    -- While-loop: Try to insert your value
    ImpGen.emit $ Imp.While loop_cond
      (try_acquire_lock <>
        Imp.If lock_acquired
         -- True branch
         (bind_acc_params <> bind_arr_params <> op_body <> do_gen_reduce <> release_lock <> break_loop)
         -- False branch
         Imp.Skip
         <>
        Imp.Op Imp.MemFence
      )
  where writeArray i arr val =
          ImpGen.copyDWIM arr i val []

compileKernelExp _ dest e =
  compilerBugS $ unlines ["Invalid target", "  " ++ show dest,
                          "for kernel expression", "  " ++ pretty e]

-- Requires that the lists are of equal length; otherwise
-- zipWith will truncate the longer list.
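-- For instance (in pseudo-notation), indexInBounds [i] [n] produces the
-- expression 0 <= i && i < n.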
indexInBounds :: [Imp.Exp] -> [Imp.Exp] -> Imp.Exp
indexInBounds inds bounds =
  foldl1 (Imp.BinOpExp LogAnd) $ zipWith checkBound inds bounds
  where checkBound ind bound =
          Imp.BinOpExp LogAnd
           (Imp.CmpOpExp (CmpSle Int32) 0 ind)
           (Imp.CmpOpExp (CmpSlt Int32) ind bound)

allThreads :: KernelConstants -> InKernelGen () -> InKernelGen Imp.KernelCode
allThreads constants = ImpGen.subImpM_ $ inKernelOperations constants'
  where constants' =
          constants { kernelThreadActive = Imp.ValueExp (BoolValue True) }

streaming :: KernelConstants -> VName -> SubExp -> InKernelGen () -> InKernelGen Imp.KernelCode
streaming constants chunksize bound m = do
  bound' <- ImpGen.subExpToDimSize bound
  let constants' =
        constants { kernelStreamed = (chunksize, bound') : kernelStreamed constants }
  ImpGen.subImpM_ (inKernelOperations constants') m

compileKernelResult :: KernelConstants -> ImpGen.ValueDestination -> KernelResult
                    -> InKernelGen ()

compileKernelResult constants dest (ThreadsReturn OneResultPerGroup what) = do
  i <- newVName "i"

  in_local_memory <- arrayInLocalMemory what
  let me = Imp.var (kernelLocalThreadId constants) int32

  if not in_local_memory then do
    write_result <-
      ImpGen.collect $
      ImpGen.copyDWIMDest dest [ImpGen.varIndex $ kernelGroupId constants] what []

    who' <- ImpGen.compileSubExp $ intConst Int32 0
    ImpGen.emit $
      Imp.If (Imp.CmpOpExp (CmpEq int32) me who') write_result mempty
    else do
      -- If the result of the group is an array in local memory, we
      -- store it by collective copying among all the threads of the
      -- group.  TODO: also do this if the array is in global memory
      -- (but this is a bit more tricky, synchronisation-wise).
      --
      -- We do the reads/writes multidimensionally, but the loop is
      -- single-dimensional.
      ws <- mapM ImpGen.compileSubExp . arrayDims =<< subExpType what
      -- Compute how many elements this thread is responsible for.
      -- Formula: (w - ltid) / group_size (rounded up).
      let w = product ws
          ltid = ImpGen.varIndex (kernelLocalThreadId constants)
          group_size = Imp.sizeToExp (kernelGroupSize constants)
          to_write = (w - ltid) `quotRoundingUp` group_size
          is = unflattenIndex ws $ ImpGen.varIndex i * group_size + ltid

      write_result <-
        ImpGen.collect $
        ImpGen.copyDWIMDest dest (ImpGen.varIndex (kernelGroupId constants) : is)
                            what is

      ImpGen.emit $ Imp.For i Int32 to_write write_result

compileKernelResult constants dest (ThreadsReturn AllThreads what) =
  ImpGen.copyDWIMDest dest [ImpGen.varIndex $ kernelGlobalThreadId constants] what []

compileKernelResult constants dest (ThreadsReturn (ThreadsPerGroup limit) what) = do
  write_result <-
    ImpGen.collect $
    ImpGen.copyDWIMDest dest [ImpGen.varIndex $ kernelGroupId constants] what []

  ImpGen.emit $ Imp.If (isActive limit) write_result mempty

compileKernelResult constants dest (ThreadsReturn ThreadsInSpace what) = do
  let is = map (ImpGen.varIndex . fst) $ kernelDimensions constants
  write_result <- ImpGen.collect $ ImpGen.copyDWIMDest dest is what []
  ImpGen.emit $ Imp.If (kernelThreadActive constants)
    write_result mempty

compileKernelResult constants dest (ConcatReturns SplitContiguous _ per_thread_elems moffset what) = do
  ImpGen.ArrayDestination (Just dest_loc) <- return dest
  let dest_loc_offset = ImpGen.offsetArray dest_loc offset
      dest' = ImpGen.ArrayDestination $ Just dest_loc_offset
  ImpGen.copyDWIMDest dest' [] (Var what) []
  where offset = case moffset of
                   Nothing -> ImpGen.compileSubExpOfType int32 per_thread_elems *
                              ImpGen.varIndex (kernelGlobalThreadId constants)
                   Just se -> ImpGen.compileSubExpOfType int32 se

compileKernelResult constants dest (ConcatReturns (SplitStrided stride) _ _ moffset what) = do
  ImpGen.ArrayDestination (Just dest_loc) <- return dest
  let dest_loc' = ImpGen.strideArray
                  (ImpGen.offsetArray dest_loc offset) $
                  ImpGen.compileSubExpOfType int32 stride
      dest' = ImpGen.ArrayDestination $ Just dest_loc'
  ImpGen.copyDWIMDest dest' [] (Var what) []
  where offset = case moffset of
                   Nothing -> ImpGen.varIndex (kernelGlobalThreadId constants)
                   Just se -> ImpGen.compileSubExpOfType int32 se

compileKernelResult constants dest (WriteReturn rws _arr dests) = do
  rws' <- mapM ImpGen.compileSubExp rws
  forM_ dests $ \(is, e) -> do
    is' <- mapM ImpGen.compileSubExp is
    let condInBounds0 = Imp.CmpOpExp (Imp.CmpSle Int32) $
                        Imp.ValueExp (IntValue (Int32Value 0))
        condInBounds1 = Imp.CmpOpExp (Imp.CmpSlt Int32)
        condInBounds i rw = Imp.BinOpExp LogAnd (condInBounds0 i) (condInBounds1 i rw)
        write = foldl (Imp.BinOpExp LogAnd) (kernelThreadActive constants) $
                zipWith condInBounds is' rws'
    actual_body' <- ImpGen.collect $
      ImpGen.copyDWIMDest dest (map (ImpGen.compileSubExpOfType int32) is) e []
    ImpGen.emit $ Imp.If write actual_body' Imp.Skip

compileKernelResult _ _ KernelInPlaceReturn{} =
  -- Already in its place... said it was a hack.
  return ()

isActive :: [(VName, SubExp)] -> Imp.Exp
isActive limit = case actives of
                    [] -> Imp.ValueExp $ BoolValue True
                    x:xs -> foldl (Imp.BinOpExp LogAnd) x xs
  where (is, ws) = unzip limit
        actives = zipWith active is $ map (ImpGen.compileSubExpOfType Bool) ws
        active i = Imp.CmpOpExp (CmpSlt Int32) (Imp.var i Bool)
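
-- For instance (in pseudo-notation), isActive [(i, n), (j, m)] above
-- yields i < n && j < m, and an empty limit list yields True.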

setSpaceIndices :: KernelSpace -> InKernelGen ()
setSpaceIndices space =
  case spaceStructure space of
    FlatThreadSpace is_and_dims ->
      flatSpaceWith gtid is_and_dims
    NestedThreadSpace is_and_dims -> do
      let (gtids, gdims, ltids, ldims) = unzip4 is_and_dims
      gdims' <- mapM ImpGen.compileSubExp gdims
      ldims' <- mapM ImpGen.compileSubExp ldims
      let (gtid_es, ltid_es) = unzip $ unflattenNestedIndex gdims' ldims' gtid
      forM_ (zip gtids gtid_es) $ \(i,e) ->
        ImpGen.emit $ Imp.SetScalar i e
      forM_ (zip ltids ltid_es) $ \(i,e) ->
        ImpGen.emit $ Imp.SetScalar i e
  where gtid = Imp.var (spaceGlobalId space) int32

        flatSpaceWith base is_and_dims = do
          let (is, dims) = unzip is_and_dims
          dims' <- mapM ImpGen.compileSubExp dims
          let index_expressions = unflattenIndex dims' base
          forM_ (zip is index_expressions) $ \(i, x) ->
            ImpGen.emit $ Imp.SetScalar i x

unflattenNestedIndex :: IntegralExp num => [num] -> [num] -> num -> [(num,num)]
unflattenNestedIndex global_dims group_dims global_id =
  zip global_is local_is
  where num_groups_dims = zipWith quotRoundingUp global_dims group_dims
        group_size = product group_dims
        group_id = global_id `Futhark.Util.IntegralExp.quot` group_size
        local_id = global_id `Futhark.Util.IntegralExp.rem` group_size

        group_is = unflattenIndex num_groups_dims group_id
        local_is = unflattenIndex group_dims local_id
        global_is = zipWith (+) local_is $ zipWith (*) group_is group_dims
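
-- A sketch of the arithmetic in 'unflattenNestedIndex' above (illustrative
-- only): with global_dims = [10], group_dims = [4] and global_id = 7, we
-- get group_id = 1 and local_id = 3, and hence a result of [(7, 3)].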

arrayInLocalMemory :: SubExp -> InKernelGen Bool
arrayInLocalMemory (Var name) = do
  res <- ImpGen.lookupVar name
  case res of
    ImpGen.ArrayVar _ entry ->
      (Space "local"==) . ImpGen.entryMemSpace <$>
      ImpGen.lookupMemory (ImpGen.memLocationName (ImpGen.entryArrayLocation entry))
    _ -> return False
arrayInLocalMemory Constant{} = return False