{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | An unstructured grab-bag of various tools and inspection
-- functions that didn't really fit anywhere else.
module Futhark.Tools
  (
    module Futhark.Construct

  , redomapToMapAndReduce
  , dissectScrema
  , sequentialStreamWholeArray

  , partitionChunkedFoldParameters

  -- * Primitive expressions
  , module Futhark.Analysis.PrimExp.Convert
  )
where

import Control.Monad.Identity

import Futhark.IR
import Futhark.IR.SOACS.SOAC
import Futhark.MonadFreshNames
import Futhark.Construct
import Futhark.Analysis.PrimExp.Convert
import Futhark.Util

-- | Turns a binding of a @redomap@ into two separate bindings, a
-- @map@ binding and a @reduce@ binding (returned in that order).
--
-- Reuses the original pattern for the @reduce@, and creates a new
-- pattern with new 'Ident's for the result of the @map@.
--
-- The tuple argument holds, in order: the width @w@, the
-- commutativity of the reduction, the reduction lambda, the map
-- lambda, the neutral elements, and the input arrays.
--
-- Only handles a pattern with an empty 'patternContextElements'; any
-- other pattern is rejected with 'error'.
redomapToMapAndReduce :: (MonadFreshNames m, Bindable lore,
                          ExpDec lore ~ (), Op lore ~ SOAC lore) =>
                         Pattern lore
                      -> ( SubExp
                         , Commutativity
                         , LambdaT lore, LambdaT lore, [SubExp]
                         , [VName])
                      -> m (Stm lore, Stm lore)
redomapToMapAndReduce (Pattern [] patelems)
                      (w, comm, redlam, map_lam, accs, arrs) = do
  (map_pat, red_pat, red_args) <-
    splitScanOrRedomap patelems w map_lam accs
  let map_bnd = mkLet [] map_pat $ Op $ Screma w (mapSOAC map_lam) arrs
      -- The reduction starts from the original neutral elements and
      -- consumes the fresh accumulator arrays produced by the map.
      (nes, red_arrs) = unzip red_args
  red_form <- reduceSOAC [Reduce comm redlam nes]
  let red_bnd = Let red_pat (defAux ()) $ Op $ Screma w red_form red_arrs
  return (map_bnd, red_bnd)
redomapToMapAndReduce _ _ =
  error "redomapToMapAndReduce does not handle a non-empty 'patternContextElements'"

-- | Split the pattern of a redomap-like SOAC so that its map part can
-- be bound separately (helper for 'redomapToMapAndReduce').
--
-- The first @length accs@ pattern elements are taken to be the
-- accumulator results; the remaining ones are the mapped-out arrays.
-- Returns:
--
-- * the pattern for the new @map@: fresh identifiers for the
--   accumulator part (named after the original pattern elements with
--   a @_map_acc@ suffix, typed as arrays with outer size @w@),
--   followed by the original array pattern elements;
--
-- * a pattern containing only the accumulator pattern elements; and
--
-- * each neutral element paired with the name of the corresponding
--   fresh map result.
splitScanOrRedomap :: (Typed dec, MonadFreshNames m) =>
                      [PatElemT dec]
                   -> SubExp -> LambdaT lore -> [SubExp]
                   -> m ([Ident], PatternT dec, [(SubExp, VName)])
splitScanOrRedomap patelems w map_lam accs = do
  let num_accs = length accs
      (acc_patelems, arr_patelems) = splitAt num_accs patelems
      -- Only the accumulator part of the lambda's return type is
      -- needed; the mapped-out part reuses the pattern's own types.
      acc_ts = take num_accs $ lambdaReturnType map_lam
  map_accpat <- zipWithM accMapPatElem acc_patelems acc_ts
  let map_arrpat = map patElemIdent arr_patelems
      map_pat = map_accpat ++ map_arrpat
      red_args = zip accs $ map identName map_accpat
  return (map_pat, Pattern [] acc_patelems, red_args)
  where
    -- Fresh identifier for the per-element accumulator values emitted
    -- by the map: an array of @acc_t@ with @w@ rows.
    accMapPatElem pe acc_t =
      newIdent (baseString (patElemName pe) ++ "_map_acc") $
        acc_t `arrayOfRow` w

-- | Turn a Screma into a Scanomap (possibly with mapout parts) and a
-- Redomap.  This is used to handle Scremas that are so complicated
-- that we cannot directly generate efficient parallel code for them.
-- In essence, what happens is the opposite of horizontal fusion.
--
-- The values that feed the reductions are bound to fresh @to_red@
-- arrays by the first Screma and then reduced by the second.
dissectScrema :: (MonadBinder m, Op (Lore m) ~ SOAC (Lore m),
                  Bindable (Lore m)) =>
                 Pattern (Lore m) -> SubExp -> ScremaForm (Lore m) -> [VName]
              -> m ()
dissectScrema pat w (ScremaForm scans reds map_lam) arrs = do
  let num_reds = redResults reds
      num_scans = scanResults scans
      -- The pattern is laid out as scan results, then reduction
      -- results, then mapped-out results.
      (scan_res, red_res, map_res) =
        splitAt3 num_scans num_reds $ patternNames pat

  to_red <- replicateM num_reds $ newVName "to_red"

  -- First pass: scans plus map, with the inputs of the reductions
  -- exposed as ordinary map outputs bound to 'to_red'.
  letBindNames (scan_res <> to_red <> map_res) $
    Op $ Screma w (scanomapSOAC scans map_lam) arrs

  -- Second pass: the reductions over those intermediate arrays.
  reduce <- reduceSOAC reds
  letBindNames red_res $ Op $ Screma w reduce to_red

-- | Turn a stream SOAC into statements that apply the stream lambda
-- to the entire input.
--
-- The lambda's parameters are split by 'partitionChunkedFoldParameters'
-- into the chunk size, one accumulator per element of @nes@, and the
-- array chunks (paired positionally with @arrs@).  The lambda body is
-- inlined into the current binder and its results are bound to @pat@.
sequentialStreamWholeArray :: (MonadBinder m, Bindable (Lore m)) =>
                              Pattern (Lore m)
                           -> SubExp -> [SubExp]
                           -> LambdaT (Lore m) -> [VName]
                           -> m ()
sequentialStreamWholeArray pat w nes lam arrs = do
  -- We just set the chunksize to w and inline the lambda body.  There
  -- is no difference between parallel and sequential streams here.
  let (chunk_size_param, fold_params, arr_params) =
        partitionChunkedFoldParameters (length nes) $ lambdaParams lam
      body = lambdaBody lam

  -- The chunk size is the full size of the array.
  letBindNames [paramName chunk_size_param] $ BasicOp $ SubExp w

  -- The accumulator parameters are initialised to the neutral element.
  forM_ (zip fold_params nes) $ \(p, ne) ->
    letBindNames [paramName p] $ BasicOp $ SubExp ne

  -- Finally, the array parameters are set to the arrays (but reshaped
  -- to make the types work out; this will be simplified rapidly).
  forM_ (zip arr_params arrs) $ \(p, arr) ->
    letBindNames [paramName p] $ BasicOp $
      Reshape (map DimCoercion $ arrayDims $ paramType p) arr

  -- Then we just inline the lambda body.
  mapM_ addStm $ bodyStms body

  -- The number of results in the body matches exactly the size (and
  -- order) of 'pat', so we bind them up here, again with a reshape to
  -- make the types work out.  Non-array results and constants are
  -- bound directly.
  forM_ (zip (patternElements pat) (bodyResult body)) $ \(pe, se) ->
    letBindNames [patElemName pe] $ BasicOp $
      case (arrayDims $ patElemType pe, se) of
        (dims, Var v) | not (null dims) -> Reshape (map DimCoercion dims) v
        _ -> SubExp se

-- | Split the parameters of a stream reduction lambda into the chunk
-- size parameter, the accumulator parameters, and the input chunk
-- parameters.  The integer argument is how many accumulators are
-- used.
--
-- Calls 'error' if the parameter list is empty.  If fewer than
-- @num_accs@ parameters follow the chunk size, all of them are
-- returned as accumulators and the chunk parameter list is empty
-- (this follows 'splitAt').
partitionChunkedFoldParameters :: Int -> [Param dec]
                               -> (Param dec, [Param dec], [Param dec])
partitionChunkedFoldParameters num_accs params =
  case params of
    [] ->
      error "partitionChunkedFoldParameters: lambda takes no parameters"
    chunk_param : rest ->
      let (acc_params, arr_params) = splitAt num_accs rest
      in (chunk_param, acc_params, arr_params)