{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
-- | A representation of nested-parallel in-kernel per-workgroup
-- expressions.
module Futhark.Representation.Kernels.KernelExp
  ( KernelExp(..)
  , GroupStreamLambda(..)
  , SplitOrdering(..)
  , CombineSpace(..)
  , combineSpace
  , scopeOfCombineSpace
  , typeCheckKernelExp
  )
  where

import Control.Monad
import Data.Monoid ((<>))
import Data.Maybe
import qualified Data.Set as S
import qualified Data.Map.Strict as M

import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.Range as Range
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Representation.Aliases
import Futhark.Representation.Ranges
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Optimise.Simplify.Lore
import Futhark.Analysis.Usage
import Futhark.Analysis.Metrics
import qualified Futhark.Analysis.ScalExp as SE
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Util.Pretty
  ((<+>), (</>), ppr, comma, commasep, Pretty, parens, text, apply, braces, annot, indent)
import qualified Futhark.TypeCheck as TC
import Futhark.Util (chunks)

-- | How the elements of an array are divided into per-thread chunks
-- (used by 'SplitSpace').
data SplitOrdering
  = SplitContiguous
    -- ^ Each thread receives a contiguous run of elements.
  | SplitStrided SubExp
    -- ^ Each thread receives elements spaced apart by the given stride.
  deriving (Eq, Ord, Show)

-- | The iteration space of a 'Combine'.  A combine may be fully or
-- partially in-place: the destination arrays listed in
-- 'cspaceScatter' work like the initial arrays of the Scatter SOAC.
data CombineSpace = CombineSpace
  { cspaceScatter :: [(SubExp, Int, VName)]
    -- ^ For each in-place destination: its outer size, the number of
    -- index/value result pairs associated with it, and the destination
    -- array itself.
  , cspaceDims :: [(VName, SubExp)]
    -- ^ For each dimension of the result: the index variable that is
    -- bound while evaluating the body, and the extent of the dimension.
  }
  deriving (Eq, Ord, Show)

-- | A 'CombineSpace' over the given dimensions, with no in-place
-- (scatter) destinations.
combineSpace :: [(VName,SubExp)] -> CombineSpace
combineSpace dims = CombineSpace { cspaceScatter = [], cspaceDims = dims }

-- | The scope introduced by a 'CombineSpace': every dimension index
-- variable is bound as an @Int32@ index.  The scatter destinations
-- contribute nothing.
scopeOfCombineSpace :: CombineSpace -> Scope lore
scopeOfCombineSpace =
  M.fromList . map (\(index_var, _) -> (index_var, IndexInfo Int32)) . cspaceDims

-- | A nested-parallel expression evaluated at workgroup level inside a
-- kernel.
data KernelExp lore = SplitSpace SplitOrdering SubExp SubExp SubExp
                      -- ^ @SplitSpace o w i elems_per_thread@.
                      --
                      -- Computes how to divide array elements to
                      -- threads in a kernel.  Returns the number of
                      -- elements in the chunk that the current thread
                      -- should take.
                      --
                      -- @w@ is the length of the outer dimension in
                      -- the array. @i@ is the current thread
                      -- index. Each thread takes at most
                      -- @elems_per_thread@ elements.
                      --
                      -- If the order @o@ is 'SplitContiguous', thread with index @i@
                      -- should receive elements
                      -- @i*elems_per_thread, i*elems_per_thread + 1,
                      -- ..., i*elems_per_thread + (elems_per_thread-1)@.
                      --
                      -- If the order @o@ is @'SplitStrided' stride@,
                      -- the thread will receive elements @i,
                      -- i+stride, i+2*stride, ...,
                      -- i+(elems_per_thread-1)*stride@.
                      --
                      -- All operands (including any @stride@) must
                      -- be of type @i32@.
                    | Combine CombineSpace [Type] [(VName,SubExp)] (Body lore)
                      -- ^ @Combine cspace ts aspace body@ will
                      -- combine values from threads to a single
                      -- (multidimensional) array.  If we define @(is,
                      -- ws) = unzip (cspaceDims cspace)@, then @ws@ is
                      -- defined the same across all threads.  The
                      -- @cspace@ defines the shape of the resulting
                      -- array, and the identifiers used to identify
                      -- each individual element.  Only threads for
                      -- which @all (\(i,w) -> i < w) aspace@ is true
                      -- will provide a value (of type @ts@), which is
                      -- generated by @body@.
                      --
                      -- @ts@ begins with one @i32@ index type per
                      -- scattered value (see 'cspaceScatter'),
                      -- followed by one type per scattered value; the
                      -- remaining types are the element types of the
                      -- ordinary results.  The scatter destination
                      -- arrays are consumed.
                      --
                      -- The result of a combine is always stored in local
                      -- memory (OpenCL terminology).
                      --
                      -- The same thread may be assigned to multiple
                      -- elements of 'Combine', if the size of the
                      -- 'CombineSpace' exceeds the group size.
                    | GroupReduce SubExp
                      (Lambda lore) [(SubExp,VName)]
                      -- ^ @GroupReduce w lam input@ (with @(nes, arrs) = unzip input@),
                      -- will perform a reduction of the arrays @arrs@ using the
                      -- associative reduction operator @lam@ and the neutral
                      -- elements @nes@.
                      --
                      -- The arrays @arrs@ must all have outer
                      -- dimension @w@, which must not be larger than
                      -- the group size.
                      --
                      -- Currently a GroupReduce consumes the input arrays, as
                      -- it uses them for scratch space to store temporary
                      -- results.
                      --
                      -- All threads in a group must participate in a
                      -- GroupReduce (due to barriers).
                      --
                      -- The length of the arrays @w@ can be smaller than the
                      -- number of elements in a group (neutral element will be
                      -- filled in), but @w@ can never be larger than the group
                      -- size.
                    | GroupScan SubExp
                      (Lambda lore) [(SubExp,VName)]
                      -- ^ Same restrictions as with 'GroupReduce'.
                    | GroupStream SubExp SubExp
                      (GroupStreamLambda lore) [SubExp] [VName]
                      -- ^ @GroupStream w maxchunk lam accs arrs@.
                      --
                      -- Morally a StreamSeq.  @w@ is the outer size
                      -- of the input arrays @arrs@, @maxchunk@ is the
                      -- maximal chunk size, and @accs@ are the
                      -- initial accumulator values.  Array-typed
                      -- accumulators are always consumed.
                    | GroupGenReduce [SubExp] [VName] (LambdaT lore) [SubExp] [SubExp] VName
                      -- ^ @GroupGenReduce ws dests op bucket values locks@.
                      --
                      -- @ws@ and @bucket@ are all @i32@.  @bucket@
                      -- indexes into the destination arrays @dests@,
                      -- and @values@ must have exactly the types of
                      -- the @dests@ elements at that index.  @op@ is
                      -- type-checked against those element types
                      -- followed by the types of @values@.  @locks@ is
                      -- an @i32@ array of shape @ws@.  The destination
                      -- arrays are consumed, and the results alias
                      -- them.
                    | Barrier [SubExp]
                      -- ^ HACK: Semantically identity, but inserts a
                      -- barrier afterwards.  This reflects a weakness
                      -- in our kernel representation.
                    deriving (Eq, Ord, Show)

-- | The function applied to each chunk of a 'GroupStream'.
data GroupStreamLambda lore = GroupStreamLambda
  { groupStreamChunkSize :: VName
    -- ^ Bound to the size of the current chunk (an @Int32@ index).
  , groupStreamChunkOffset :: VName
    -- ^ Bound to the offset of the current chunk (an @Int32@ index).
  , groupStreamAccParams :: [LParam lore]
    -- ^ Accumulator parameters; their types are the result types of
    -- the stream.
  , groupStreamArrParams :: [LParam lore]
    -- ^ One parameter per input array, each holding a chunk of
    -- 'groupStreamChunkSize' rows of that array.
  , groupStreamLambdaBody :: Body lore
    -- ^ Computes the new accumulator values.
  }

deriving instance Annotations lore => Eq (GroupStreamLambda lore)
deriving instance Annotations lore => Ord (GroupStreamLambda lore)
deriving instance Annotations lore => Show (GroupStreamLambda lore)

-- | For every kernel expression, 'safeOp' is 'False' and 'cheapOp' is
-- 'True'.
instance Attributes lore => IsOp (KernelExp lore) where
  cheapOp = const True
  safeOp = const False

instance Attributes lore => TypedOp (KernelExp lore) where
  -- The size of the current thread's chunk.
  opType SplitSpace{} =
    pure (staticShapes [Prim int32])

  -- 'ts' is laid out as: one index type per scattered value, then one
  -- value type per scattered value, then the element types of the
  -- ordinary results.  Each scatter destination yields an array whose
  -- rows have the first value type of its group; every ordinary result
  -- is an array shaped like the combine space.
  opType (Combine (CombineSpace scatter dims) ts _ _) =
    pure (staticShapes (scatter_ts ++ map (`arrayOfShape` result_shape) plain_ts))
    where (dest_ws, dest_ns, _) = unzip3 scatter
          num_scattered = sum dest_ns
          (_index_ts, rest_ts) = splitAt num_scattered ts
          (value_ts, plain_ts) = splitAt num_scattered rest_ts
          scatter_ts = zipWith arrayOfRow
                       (concatMap (take 1) (chunks dest_ns value_ts))
                       dest_ws
          result_shape = Shape (map snd dims)

  -- The operator's return types.
  opType (GroupReduce _ lam _) =
    pure (staticShapes (lambdaReturnType lam))

  -- One array of length 'w' per operator return type.
  opType (GroupScan w lam _) =
    pure (staticShapes [ t `arrayOfRow` w | t <- lambdaReturnType lam ])

  -- The final accumulator values.
  opType (GroupStream _ _ lam _ _) =
    pure (staticShapes (paramType <$> groupStreamAccParams lam))

  -- The (updated) destination arrays.
  opType (GroupGenReduce _ dests _ _ _ _) =
    fmap staticShapes (traverse lookupType dests)

  -- A barrier returns its operands unchanged.
  opType (Barrier ses) =
    fmap staticShapes (traverse subExpType ses)

-- | Only a strided split mentions a variable: its stride.
instance FreeIn SplitOrdering where
  freeIn order =
    case order of
      SplitContiguous     -> mempty
      SplitStrided stride -> freeIn stride

instance Attributes lore => FreeIn (KernelExp lore) where
  freeIn kexp =
    case kexp of
      SplitSpace o w i elems_per_thread ->
        mconcat [freeIn o, freeIn [w, i, elems_per_thread]]
      Combine (CombineSpace scatter dims) ts active body ->
        -- Only the extents of the combine-space dimensions are counted,
        -- not their index names.  NOTE(review): the body is checked
        -- with those index names in scope (see 'scopeOfCombineSpace'),
        -- yet they are not subtracted from 'freeInBody' here -- confirm
        -- this is intended.
        mconcat [ freeIn scatter, freeIn (map snd dims), freeIn ts
                , freeIn active, freeInBody body ]
      GroupReduce w lam input ->
        mconcat [freeIn w, freeInLambda lam, freeIn input]
      GroupScan w lam input ->
        mconcat [freeIn w, freeInLambda lam, freeIn input]
      GroupStream w maxchunk lam accs arrs ->
        mconcat [freeIn w, freeIn maxchunk, freeIn lam, freeIn accs, freeIn arrs]
      GroupGenReduce w dests op bucket values locks ->
        mconcat [ freeIn w, freeIn dests, freeInLambda op
                , freeIn bucket, freeIn values, freeIn locks ]
      Barrier ses ->
        freeIn ses

-- | The free variables of the body, minus the chunk size, chunk offset
-- and parameter names that the lambda itself binds.
instance Attributes lore => FreeIn (GroupStreamLambda lore) where
  freeIn lam =
    freeInBody (groupStreamLambdaBody lam) `S.difference` bound_by_lambda
    where bound_by_lambda =
            S.fromList $
            [groupStreamChunkOffset lam, groupStreamChunkSize lam] ++
            map paramName (groupStreamAccParams lam ++ groupStreamArrParams lam)

-- | A 'SplitSpace' result lies between zero and @elems_per_thread@;
-- nothing is known about the results of any other kernel expression.
instance Ranged inner => RangedOp (KernelExp inner) where
  opRanges (SplitSpace _ _ _ elems_per_thread) = [chunk_range]
    where chunk_range = (Just $ ScalarBound 0, Just $ ScalarBound upper)
          upper = SE.subExpToScalExp elems_per_thread int32
  opRanges _ = repeat unknownRange

instance (Attributes lore, Aliased lore) => AliasedOp (KernelExp lore) where
  -- Freshly computed results alias nothing.  The results of a
  -- 'GroupGenReduce' alias their destination arrays, and a 'Barrier'
  -- passes through the aliases of its operands.
  opAliases SplitSpace{} =
    [mempty]
  opAliases Combine{} =
    [mempty]
  opAliases (GroupReduce _ lam _) =
    replicate (length (lambdaReturnType lam)) mempty
  opAliases (GroupScan _ lam _) =
    replicate (length (lambdaReturnType lam)) mempty
  opAliases (GroupStream _ _ lam _ _) =
    map (const mempty) $ groupStreamAccParams lam
  opAliases (GroupGenReduce _ dests _ _ _ _) =
    map S.singleton dests
  opAliases (Barrier ses) = map subExpAliases ses

  -- GroupReduce and GroupScan use their input arrays as scratch space
  -- for temporary results, so those arrays are consumed.
  consumedInOp (GroupReduce _ _ input) =
    S.fromList $ map snd input
  consumedInOp (GroupScan _ _ input) =
    S.fromList $ map snd input
  consumedInOp (GroupStream _ _ lam accs arrs) =
    -- GroupStream always consumes array-typed accumulators.  This
    -- guarantees that we can use their storage for the result of the
    -- lambda.  Consumed lambda parameters are mapped back to the
    -- variable passed as the corresponding argument, if any.
    S.map consumedArray $
    S.fromList (map paramName acc_params) <> consumedInBody body
    where GroupStreamLambda _ _ acc_params arr_params body = lam
          consumedArray v = fromMaybe v $ subExpVar =<< lookup v params_to_arrs
          params_to_arrs = zip (map paramName $ acc_params ++ arr_params) $
                           accs ++ map Var arrs
  -- The destination arrays are updated in place.
  consumedInOp (GroupGenReduce _ dests _ _ _ _) =
    S.fromList dests

  consumedInOp SplitSpace{} = mempty
  consumedInOp Barrier{} = mempty
  -- The scatter destinations of a combine are written in place: the
  -- type checker ('typeCheckKernelExp') consumes them and 'usageInOp'
  -- reports them as consumed, so they must be reported here as well,
  -- together with whatever the body consumes.
  consumedInOp (Combine cspace _ _ body) =
    S.fromList [ arr | (_, _, arr) <- cspaceScatter cspace ] <>
    consumedInBody body
  consumedInOp (Combine _ _ _ body) = consumedInBody body

instance Substitute SplitOrdering where
  substituteNames subst order =
    case order of
      SplitContiguous     -> SplitContiguous
      SplitStrided stride -> SplitStrided (substituteNames subst stride)

-- | Substitutes in the scatter widths and destinations, and in the
-- dimensions.  The per-destination counts are plain 'Int's and are left
-- alone.
instance Substitute CombineSpace where
  substituteNames substs cspace =
    CombineSpace
    { cspaceScatter = [ (substituteNames substs w, n, substituteNames substs arr)
                      | (w, n, arr) <- cspaceScatter cspace ]
    , cspaceDims = substituteNames substs (cspaceDims cspace)
    }

instance Attributes lore => Substitute (KernelExp lore) where
  substituteNames subst kexp =
    case kexp of
      SplitSpace o w i elems_per_thread ->
        SplitSpace (sub o) (sub w) (sub i) (sub elems_per_thread)
      Combine cspace ts active body ->
        -- The result types are left untouched.
        Combine (sub cspace) ts (sub active) (sub body)
      GroupReduce w lam input ->
        GroupReduce (sub w) (sub lam) (sub input)
      GroupScan w lam input ->
        GroupScan (sub w) (sub lam) (sub input)
      GroupStream w maxchunk lam accs arrs ->
        GroupStream (sub w) (sub maxchunk) (sub lam) (sub accs) (sub arrs)
      GroupGenReduce w dests op bucket vs locks ->
        GroupGenReduce (sub w) (sub dests) (sub op) (sub bucket) (sub vs) (sub locks)
      Barrier ses ->
        Barrier (sub ses)
    -- The explicit signature keeps 'sub' polymorphic, since it is used
    -- at many different types.
    where sub :: Substitute a => a -> a
          sub = substituteNames subst

-- | Substitution is applied to every component, including the names
-- the lambda binds.
instance Attributes lore => Substitute (GroupStreamLambda lore) where
  substituteNames subst (GroupStreamLambda size offset accs arrs body) =
    GroupStreamLambda (sub size) (sub offset) (sub accs) (sub arrs) (sub body)
    where sub :: Substitute a => a -> a
          sub = substituteNames subst

instance Rename SplitOrdering where
  rename order =
    case order of
      SplitContiguous     -> pure SplitContiguous
      SplitStrided stride -> fmap SplitStrided (rename stride)

-- | Renaming of a combine space is delegated to 'substituteRename'.
instance Rename CombineSpace where
  rename cspace = substituteRename cspace

-- | Each operand is renamed left to right, in the order in which the
-- constructor stores it.
instance Renameable lore => Rename (KernelExp lore) where
  rename (SplitSpace o w i elems_per_thread) = do
    o' <- rename o
    w' <- rename w
    i' <- rename i
    elems_per_thread' <- rename elems_per_thread
    return $ SplitSpace o' w' i' elems_per_thread'
  rename (Combine cspace ts active body) = do
    cspace' <- rename cspace
    ts' <- rename ts
    active' <- rename active
    body' <- rename body
    return $ Combine cspace' ts' active' body'
  rename (GroupReduce w lam input) = do
    w' <- rename w
    lam' <- rename lam
    input' <- rename input
    return $ GroupReduce w' lam' input'
  rename (GroupScan w lam input) = do
    w' <- rename w
    lam' <- rename lam
    input' <- rename input
    return $ GroupScan w' lam' input'
  rename (GroupStream w maxchunk lam accs arrs) = do
    w' <- rename w
    maxchunk' <- rename maxchunk
    lam' <- rename lam
    accs' <- rename accs
    arrs' <- rename arrs
    return $ GroupStream w' maxchunk' lam' accs' arrs'
  rename (GroupGenReduce w dests op bucket vs locks) = do
    w' <- rename w
    dests' <- rename dests
    op' <- rename op
    bucket' <- rename bucket
    vs' <- rename vs
    locks' <- rename locks
    return $ GroupGenReduce w' dests' op' bucket' vs' locks'
  rename (Barrier ses) =
    fmap Barrier (mapM rename ses)

-- | The chunk size, chunk offset and all parameters are binding
-- occurrences, so fresh names are introduced for them around the
-- renaming of the whole lambda.
instance Renameable lore => Rename (GroupStreamLambda lore) where
  rename (GroupStreamLambda chunk_size chunk_offset acc_params arr_params body) =
    bindingForRename bound_names $ do
      chunk_size' <- rename chunk_size
      chunk_offset' <- rename chunk_offset
      acc_params' <- rename acc_params
      arr_params' <- rename arr_params
      body' <- rename body
      return $ GroupStreamLambda chunk_size' chunk_offset' acc_params' arr_params' body'
    where bound_names =
            chunk_size : chunk_offset : map paramName (acc_params ++ arr_params)

instance (Attributes lore,
          Attributes (Aliases lore),
          CanBeAliased (Op lore)) => CanBeAliased (KernelExp lore) where
  type OpWithAliases (KernelExp lore) = KernelExp (Aliases lore)

  -- Run alias analysis on every nested lambda and body; constructors
  -- without nested code are rebuilt unchanged.
  addOpAliases kexp =
    case kexp of
      SplitSpace o w i elems_per_thread ->
        SplitSpace o w i elems_per_thread
      Combine cspace ts active body ->
        Combine cspace ts active (Alias.analyseBody body)
      GroupReduce w lam input ->
        GroupReduce w (Alias.analyseLambda lam) input
      GroupScan w lam input ->
        GroupScan w (Alias.analyseLambda lam) input
      GroupStream w maxchunk lam accs arrs ->
        GroupStream w maxchunk (withAliases lam) accs arrs
      GroupGenReduce w dests op bucket vs locks ->
        GroupGenReduce w dests (Alias.analyseLambda op) bucket vs locks
      Barrier ses ->
        Barrier ses
    where withAliases (GroupStreamLambda chunk_size chunk_offset acc_params arr_params body) =
            GroupStreamLambda chunk_size chunk_offset acc_params arr_params $
            Alias.analyseBody body

  -- The inverse: strip alias information from all nested code.
  removeOpAliases kexp =
    case kexp of
      SplitSpace o w i elems_per_thread ->
        SplitSpace o w i elems_per_thread
      Combine cspace ts active body ->
        Combine cspace ts active (removeBodyAliases body)
      GroupReduce w lam input ->
        GroupReduce w (removeLambdaAliases lam) input
      GroupScan w lam input ->
        GroupScan w (removeLambdaAliases lam) input
      GroupStream w maxchunk lam accs arrs ->
        GroupStream w maxchunk (withoutAliases lam) accs arrs
      GroupGenReduce w dests op bucket vs locks ->
        GroupGenReduce w dests (removeLambdaAliases op) bucket vs locks
      Barrier ses ->
        Barrier ses
    where withoutAliases (GroupStreamLambda chunk_size chunk_offset acc_params arr_params body) =
            GroupStreamLambda chunk_size chunk_offset acc_params arr_params $
            removeBodyAliases body

instance (Attributes lore,
          Attributes (Ranges lore),
          CanBeRanged (Op lore)) => CanBeRanged (KernelExp lore) where
  type OpWithRanges (KernelExp lore) = KernelExp (Ranges lore)

  -- Run range analysis on every nested lambda and body; constructors
  -- without nested code are rebuilt unchanged.
  addOpRanges kexp =
    case kexp of
      SplitSpace o w i elems_per_thread ->
        SplitSpace o w i elems_per_thread
      Combine cspace ts active body ->
        Combine cspace ts active (Range.runRangeM (Range.analyseBody body))
      GroupReduce w lam input ->
        GroupReduce w (Range.runRangeM (Range.analyseLambda lam)) input
      GroupScan w lam input ->
        GroupScan w (Range.runRangeM (Range.analyseLambda lam)) input
      GroupStream w maxchunk lam accs arrs ->
        GroupStream w maxchunk (withRanges lam) accs arrs
      GroupGenReduce w dests op bucket vs locks ->
        GroupGenReduce w dests (Range.runRangeM (Range.analyseLambda op)) bucket vs locks
      Barrier ses ->
        Barrier ses
    where withRanges (GroupStreamLambda chunk_size chunk_offset acc_params arr_params body) =
            GroupStreamLambda chunk_size chunk_offset acc_params arr_params $
            Range.runRangeM $ Range.analyseBody body

  -- The inverse: strip range information from all nested code.
  removeOpRanges kexp =
    case kexp of
      SplitSpace o w i elems_per_thread ->
        SplitSpace o w i elems_per_thread
      Combine cspace ts active body ->
        Combine cspace ts active (removeBodyRanges body)
      GroupReduce w lam input ->
        GroupReduce w (removeLambdaRanges lam) input
      GroupScan w lam input ->
        GroupScan w (removeLambdaRanges lam) input
      GroupStream w maxchunk lam accs arrs ->
        GroupStream w maxchunk (withoutRanges lam) accs arrs
      GroupGenReduce w dests op bucket vs locks ->
        GroupGenReduce w dests (removeLambdaRanges op) bucket vs locks
      Barrier ses ->
        Barrier ses
    where withoutRanges (GroupStreamLambda chunk_size chunk_offset acc_params arr_params body) =
            GroupStreamLambda chunk_size chunk_offset acc_params arr_params $
            removeBodyRanges body

instance (Attributes lore, CanBeWise (Op lore)) => CanBeWise (KernelExp lore) where
  type OpWithWisdom (KernelExp lore) = KernelExp (Wise lore)

  -- Strip simplifier wisdom from every nested lambda and body;
  -- constructors without nested code are rebuilt unchanged.
  removeOpWisdom kexp =
    case kexp of
      SplitSpace o w i elems_per_thread ->
        SplitSpace o w i elems_per_thread
      Combine cspace ts active body ->
        Combine cspace ts active (removeBodyWisdom body)
      GroupReduce w lam input ->
        GroupReduce w (removeLambdaWisdom lam) input
      GroupScan w lam input ->
        GroupScan w (removeLambdaWisdom lam) input
      GroupStream w maxchunk lam accs arrs ->
        GroupStream w maxchunk (unwise lam) accs arrs
      GroupGenReduce w dests op bucket vs locks ->
        GroupGenReduce w dests (removeLambdaWisdom op) bucket vs locks
      Barrier ses ->
        Barrier ses
    where unwise (GroupStreamLambda chunk_size chunk_offset acc_params arr_params body) =
            GroupStreamLambda chunk_size chunk_offset acc_params arr_params $
            removeBodyWisdom body

-- | Kernel expressions use the default 'ST.IndexOp' methods.
instance ST.IndexOp (KernelExp lore)

-- | A 'Combine' consumes whatever its body consumes, plus its scatter
-- destination arrays.  No other kernel expression reports usage.
instance Aliased lore => UsageInOp (KernelExp lore) where
  usageInOp (Combine cspace _ _ body) =
    foldMap UT.consumedUsage (consumed_in_body ++ scatter_dests)
    where consumed_in_body = S.toList (consumedInBody body)
          scatter_dests = [ arr | (_, _, arr) <- cspaceScatter cspace ]
  usageInOp _ = mempty

-- | Leaf constructors are recorded by name; constructors carrying code
-- also record the metrics of that code, nested under their name.
instance OpMetrics (Op lore) => OpMetrics (KernelExp lore) where
  opMetrics kexp =
    case kexp of
      SplitSpace{} -> seen "SplitSpace"
      Combine{} -> seen "Combine"
      Barrier{} -> seen "Barrier"
      GroupReduce _ lam _ -> inside "GroupReduce" (lambdaMetrics lam)
      GroupScan _ lam _ -> inside "GroupScan" (lambdaMetrics lam)
      GroupGenReduce _ _ op _ _ _ -> inside "GroupGenReduce" (lambdaMetrics op)
      GroupStream _ _ lam _ _ ->
        inside "GroupStream" (bodyMetrics (groupStreamLambdaBody lam))

-- | Type-check a 'KernelExp'.  Operand types are checked against what
-- each construct expects; for 'Combine', the scatter destination
-- arrays are additionally consumed.
typeCheckKernelExp :: TC.Checkable lore => KernelExp (Aliases lore) -> TC.TypeM lore ()

-- NOTE(review): the operands of a 'Barrier' are not checked here.
typeCheckKernelExp Barrier{} = return ()

-- Any stride, the width, the thread index and the chunk bound must all
-- be i32.
typeCheckKernelExp (SplitSpace o w i elems_per_thread) = do
  case o of
    SplitContiguous     -> return ()
    SplitStrided stride -> TC.require [Prim int32] stride
  mapM_ (TC.require [Prim int32]) [w, i, elems_per_thread]

typeCheckKernelExp (Combine cspace@(CombineSpace scatter dims) ts aspace body) = do
  -- The extents of the combine space must be i32.
  mapM_ (TC.require [Prim int32]) ws
  -- Everything else is checked with the combine-space indices in scope.
  TC.binding (scopeOfCombineSpace cspace) $ do
    -- 'ts' starts with one index type per scattered value, followed by
    -- one value type per scattered value.
    let (_as_ws, as_ns, _as_vs) = unzip3 scatter
        num_scatters = sum as_ns
        ts_is = take num_scatters ts
        ts_vs = take num_scatters $ drop num_scatters ts

    unless (length ts_is == num_scatters && length ts_vs == num_scatters) $
      TC.bad $ TC.TypeError "Combine: inconsistent return type annotation."

    forM_ ts_is $ \ts_i -> unless (Prim int32 == ts_i) $
      TC.bad $ TC.TypeError "Combine: index return type must be i32."

    -- Each destination's outer size must be i32, and each value type in
    -- its group must match a row of the destination array.
    to_consume <- forM (zip (chunks as_ns ts_vs) scatter) $ \(ts_vs', (aw, _, a)) -> do
      TC.require [Prim int32] aw
      forM_ ts_vs' $ \ts_v -> TC.requireI [ts_v `arrayOfRow` aw] a
      return a
    -- Consume all at once because it is valid to do two scatters to the same array.
    TC.consume . mconcat =<< mapM TC.lookupAliases to_consume

    -- Finally the declared result types, the active space (whose
    -- indices must be i32 variables and bounds i32), and the body,
    -- which must return 'ts'.
    mapM_ TC.checkType ts
    mapM_ (TC.requireI [Prim int32]) a_is
    mapM_ (TC.require [Prim int32]) a_ws
    TC.checkLambdaBody ts body
  where ws = map snd dims
        (a_is, a_ws) = unzip aspace

-- Reductions and scans share a checker.
typeCheckKernelExp (GroupReduce w lam input) =
  checkScanOrReduce w lam input

typeCheckKernelExp (GroupScan w lam input) =
  checkScanOrReduce w lam input

-- NOTE(review): 'consumedInOp' reports 'dests' as consumed, but no
-- 'TC.consume' happens here, and 'ws' is not compared against the
-- shape of 'dests' -- confirm both are intended.
typeCheckKernelExp (GroupGenReduce ws dests op bucket vs locks) = do
  -- The shape 'ws' and the bucket indices must be i32.
  mapM_ (TC.require [Prim int32]) ws

  mapM_ (TC.require [Prim int32]) bucket

  -- The element type of each destination once indexed by the whole
  -- bucket; the values written must have exactly these types.
  dest_row_ts <- mapM (fmap (stripArray (length bucket)) . lookupType) dests

  vs_ts <- mapM subExpType vs
  unless (vs_ts == dest_row_ts) $
    TC.bad $ TC.TypeError $ "Destination arrays have type " ++
    pretty dest_row_ts ++ ", but values to write have type " ++ pretty vs_ts

  -- The lock array is an i32 array of shape 'ws'.
  TC.requireI [Prim int32 `arrayOfShape` Shape ws] locks

  -- The operator is checked against the destination element types
  -- followed by the value types, with no aliasing information.
  let asArg t = (t, mempty)
  TC.checkLambda op $ map asArg $ dest_row_ts ++ vs_ts

typeCheckKernelExp (GroupStream w maxchunk lam accs arrs) = do
  TC.require [Prim int32] w
  TC.require [Prim int32] maxchunk

  acc_args <- mapM TC.checkArg accs
  arr_args <- TC.checkSOACArrayArgs w arrs

  checkGroupStreamLambda acc_args arr_args
  where GroupStreamLambda block_size _ acc_params arr_params body = lam
        checkGroupStreamLambda acc_args arr_args = do
          -- Accumulator parameters must have exactly the argument types.
          unless (map TC.argType acc_args == map paramType acc_params) $
            TC.bad $ TC.TypeError
            "checkGroupStreamLambda: wrong accumulator arguments."

          -- Each array parameter holds 'block_size' rows of its input.
          let arr_block_ts =
                map ((`arrayOfRow` Var block_size) . TC.argType) arr_args
          unless (map paramType arr_params == arr_block_ts) $
            TC.bad $ TC.TypeError
            "checkGroupStreamLambda: wrong array arguments."

          -- Inside the body, only the lambda parameters may be consumed
          -- ('TC.consumeOnlyParams'), each paired with the aliases of
          -- its argument.  The body must return the accumulator types.
          let acc_consumable =
                zip (map paramName acc_params) (map TC.argAliases acc_args)
              arr_consumable =
                zip (map paramName arr_params) (map TC.argAliases arr_args)
              consumable = acc_consumable ++ arr_consumable
          TC.binding (scopeOf lam) $ TC.consumeOnlyParams consumable $ do
            TC.checkLambdaParams acc_params
            TC.checkLambdaParams arr_params
            TC.checkLambdaBody (map TC.argType acc_args) body

-- | Shared checker for 'GroupReduce' and 'GroupScan'.  The width must
-- be i32.  The operator is checked against two leading i32 arguments,
-- then the neutral elements, then the arguments derived from the input
-- arrays by 'TC.checkSOACArrayArgs'.  Aliasing information is dropped
-- from all of these arguments.
checkScanOrReduce :: TC.Checkable lore =>
                     SubExp -> Lambda (Aliases lore) -> [(SubExp, VName)]
                  -> TC.TypeM lore ()
checkScanOrReduce w lam input = do
  TC.require [Prim int32] w
  ne_args <- mapM (TC.checkArg . fst) input
  arr_args <- TC.checkSOACArrayArgs w (map snd input)
  let index_args = replicate 2 (Prim int32, mempty)
  TC.checkLambda lam $ index_args ++ map TC.noArgAliases (ne_args ++ arr_args)

-- | The chunk size and offset are bound as @Int32@ indices, on top of
-- the accumulator and array parameters.  'M.union' is left-biased, so
-- the two index bindings take precedence over any parameter with the
-- same name.
instance Scoped lore (GroupStreamLambda lore) where
  scopeOf lam =
    M.fromList [ (groupStreamChunkSize lam, IndexInfo Int32)
               , (groupStreamChunkOffset lam, IndexInfo Int32) ]
    `M.union`
    scopeOfLParams (groupStreamAccParams lam ++ groupStreamArrParams lam)

instance PrettyLore lore => Pretty (KernelExp lore) where
  ppr kexp =
    case kexp of
      SplitSpace o w i elems_per_thread ->
        text "splitSpace" <> splitSuffix o <>
        parens (commasep [ppr w, ppr i, ppr elems_per_thread])

      Combine (CombineSpace scatter cspace) ts active body ->
        text "combine" <>
        apply (map pprScatter scatter ++
               map pprDim cspace ++
               [apply (map ppr ts), ppr active]) <+> text "{" </>
        indent 2 (ppr body) </>
        text "}"

      GroupReduce w lam input ->
        pprGroupOp "reduce" w lam input

      GroupScan w lam input ->
        pprGroupOp "scan" w lam input

      GroupStream w maxchunk lam accs arrs ->
        text "stream" <>
        parens (ppr w <> comma <+> ppr maxchunk <> comma </>
                ppr lam <> comma </>
                braces (commasep $ map ppr accs) <> comma </>
                commasep (map ppr arrs))

      GroupGenReduce w dests op bucket vs locks ->
        text "gen_reduce" <>
        parens (ppr w <> comma </>
                braces (commasep $ map ppr dests) <> comma </>
                ppr op <> comma </>
                braces (commasep $ map ppr bucket) <> comma </>
                braces (commasep $ map ppr vs) <> comma </>
                ppr locks)

      Barrier ses ->
        text "barrier" <> parens (commasep $ map ppr ses)
    -- Contiguous splits print no suffix; strided ones print their stride.
    where splitSuffix SplitContiguous       = mempty
          splitSuffix (SplitStrided stride) = text "Strided" <> parens (ppr stride)
          -- A scatter destination prints as its count and array name.
          pprScatter (_, n, a) = text "@" <> ppr (n, a)
          -- A combine-space dimension prints as a bound on its index.
          pprDim (i, w) = ppr i <+> text "<" <+> ppr w
          -- Reductions and scans share one textual layout.
          pprGroupOp name w lam input =
            text name <> parens (commasep [ppr w,
                                           ppr lam,
                                           braces (commasep $ map ppr nes),
                                           commasep $ map ppr els])
            where (nes, els) = unzip input

-- | Prints as @fn (int size, int offset, params...) => body@, annotated
-- with any parameter annotations.
instance PrettyLore lore => Pretty (GroupStreamLambda lore) where
  ppr lam =
    annot (mapMaybe ppAnnot params) $
    text "fn" <+>
    parens (commasep (intParam chunk_size : intParam chunk_offset : map ppr params)) <+>
    text "=>" </> indent 2 (ppr body)
    where GroupStreamLambda chunk_size chunk_offset acc_params arr_params body = lam
          params = acc_params ++ arr_params
          intParam v = text "int" <+> ppr v