{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | The code generator cannot handle the array combinators (@map@ and
-- friends), so this module was written to transform them into the
-- equivalent do-loops.  The transformation is currently rather naive,
-- and it is certainly worth considering whether such transformations
-- can be expressed in-place.
module Futhark.Transform.FirstOrderTransform
  ( transformFunDef,
    transformConsts,
    FirstOrderRep,
    Transformer,
    transformStmRecursively,
    transformLambda,
    transformSOAC,
  )
where

import Control.Monad.Except
import Control.Monad.State
import Data.List (find, zip4)
import qualified Data.Map.Strict as M
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.IR as AST
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Util (chunks, splitAt3)

-- | The constraints that must hold for a rep in order to be the
-- target of first-order transformation.
--
-- The target rep must share its let-binding and lambda-parameter
-- decorations with 'SOACS', which lets patterns from the source
-- program be reused as-is.  It must also support construction through
-- the builder ('Buildable', 'BuilderOps'), and its operations must
-- admit alias analysis ('CanBeAliased').
type FirstOrderRep rep =
  ( LetDec SOACS ~ LetDec rep,
    LParamInfo SOACS ~ LParamInfo rep,
    Buildable rep,
    BuilderOps rep,
    CanBeAliased (Op rep)
  )

-- | First-order-transform a single function, with the given scope
-- provided by top-level constants.
--
-- The body is transformed with 'transformBody' inside a builder whose
-- scope is the constants' scope extended with the function's
-- parameters.  The entry point, attributes, name, return types and
-- parameters are carried over unchanged.
transformFunDef ::
  (MonadFreshNames m, FirstOrderRep torep) =>
  Scope torep ->
  FunDef SOACS ->
  m (AST.FunDef torep)
transformFunDef consts_scope (FunDef entry attrs fname rettype params body) = do
  -- 'transformBody' gathers its statements into the body itself (it
  -- uses 'buildBody_'), so the statements returned by 'runBuilderT'
  -- are discarded here (presumably they are always empty).
  (body', _) <- modifyNameSource $ runState $ runBuilderT m consts_scope
  pure $ FunDef entry attrs fname rettype params body'
  where
    m = localScope (scopeOfFParams params) $ transformBody body

-- | First-order-transform these top-level constants.
--
-- Each statement is transformed in order with
-- 'transformStmRecursively', starting from an empty scope, and the
-- statements emitted by the builder are returned.
transformConsts ::
  (MonadFreshNames m, FirstOrderRep torep) =>
  Stms SOACS ->
  m (AST.Stms torep)
transformConsts stms =
  snd <$> modifyNameSource (runState $ runBuilderT m mempty)
  where
    m = mapM_ transformStmRecursively stms

-- | The constraints that a monad must uphold in order to be used for
-- first-order transformation.
--
-- The monad must be a builder ('MonadBuilder') that can locally extend
-- its scope ('LocalScope').  Its rep must share lambda-parameter
-- decorations with 'SOACS', support construction ('Buildable',
-- 'BuilderOps'), and have operations that admit alias analysis
-- ('CanBeAliased').
type Transformer m =
  ( LParamInfo SOACS ~ LParamInfo (Rep m),
    MonadBuilder m,
    LocalScope (Rep m) m,
    Buildable (Rep m),
    BuilderOps (Rep m),
    CanBeAliased (Op (Rep m))
  )

-- | Transform every statement of a body with
-- 'transformStmRecursively', keeping the body's result unchanged.
transformBody ::
  (Transformer m, LetDec (Rep m) ~ LetDec SOACS) =>
  Body ->
  m (AST.Body (Rep m))
transformBody (Body () stms res) = buildBody_ $ do
  mapM_ transformStmRecursively stms
  pure res

-- | First transform any nested t'Body' or t'Lambda' elements, then
-- apply 'transformSOAC' if the expression is a SOAC.
--
-- A SOAC statement first has its lambdas transformed with
-- 'transformLambda' (through 'mapSOACM'), and the result is then
-- handed to 'transformSOAC'.  Any other statement is rebuilt with
-- 'mapExpM' and re-bound to the same pattern.  During the rebuild,
-- nested bodies are transformed in the scope they introduce, while
-- types and parameters are passed through unchanged.  In both cases
-- the emitted statements carry the original statement's auxiliary
-- information via 'auxing'.
transformStmRecursively ::
  (Transformer m, LetDec (Rep m) ~ LetDec SOACS) =>
  Stm ->
  m ()
transformStmRecursively (Let pat aux (Op soac)) =
  auxing aux $ transformSOAC pat =<< mapSOACM soacTransform soac
  where
    soacTransform = identitySOACMapper {mapOnSOACLambda = transformLambda}
transformStmRecursively (Let pat aux e) =
  auxing aux $ letBind pat =<< mapExpM transform e
  where
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody,
          mapOnRetType = pure,
          mapOnBranchType = pure,
          mapOnFParam = pure,
          mapOnLParam = pure,
          -- Should be unreachable, because a top-level 'Op' is handled
          -- by the equation above.  This assumes 'mapExpM' only
          -- consults 'mapOnOp' for an 'Op' at the root of the
          -- expression.
          mapOnOp = error "Unhandled Op in first order transform"
        }

-- Produce scratch "arrays" for the Map and Scan outputs of Screma.
-- "Arrays" is in quotes because some of those may be accumulators.
--
-- The types in @ts@ are handled as follows:
--
-- * An accumulator type that equals the type of one of the input
--   arrays @arrs@ reuses the first such input directly.
-- * Every other type gets a fresh binding (named from @"result"@) to
--   a blank value produced by 'eBlank'.
resultArray :: Transformer m => [VName] -> [Type] -> m [VName]
resultArray arrs ts = do
  arrs_ts <- mapM lookupType arrs
  let oneArray t@Acc {}
        | Just (v, _) <- find ((== t) . snd) (zip arrs arrs_ts) =
          pure v
      oneArray t =
        letExp "result" =<< eBlank t
  mapM oneArray ts

-- | Transform a single 'SOAC' into a do-loop.  The body of the lambda
-- is untouched, and may or may not contain further 'SOAC's depending
-- on the given rep.
transformSOAC ::
  Transformer m =>
  AST.Pat (Rep m) ->
  SOAC (Rep m) ->
  m ()
transformSOAC :: Pat (Rep m) -> SOAC (Rep m) -> m ()
transformSOAC Pat (Rep m)
pat (Screma SubExp
w [VName]
arrs form :: ScremaForm (Rep m)
form@(ScremaForm [Scan (Rep m)]
scans [Reduce (Rep m)]
reds Lambda (Rep m)
map_lam)) = do
  -- See Note [Translation of Screma].
  --
  -- Start by combining all the reduction and scan parts into a single
  -- operator
  let Reduce Commutativity
_ Lambda (Rep m)
red_lam [SubExp]
red_nes = [Reduce (Rep m)] -> Reduce (Rep m)
forall rep. Buildable rep => [Reduce rep] -> Reduce rep
singleReduce [Reduce (Rep m)]
reds
      Scan Lambda (Rep m)
scan_lam [SubExp]
scan_nes = [Scan (Rep m)] -> Scan (Rep m)
forall rep. Buildable rep => [Scan rep] -> Scan rep
singleScan [Scan (Rep m)]
scans
      ([Type]
scan_arr_ts, [Type]
_red_ts, [Type]
map_arr_ts) =
        Int -> Int -> [Type] -> ([Type], [Type], [Type])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes) ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) ([Type] -> ([Type], [Type], [Type]))
-> [Type] -> ([Type], [Type], [Type])
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm (Rep m) -> [Type]
forall rep. SubExp -> ScremaForm rep -> [Type]
scremaType SubExp
w ScremaForm (Rep m)
form

  [VName]
scan_arrs <- [VName] -> [Type] -> m [VName]
forall (m :: * -> *).
Transformer m =>
[VName] -> [Type] -> m [VName]
resultArray [] [Type]
scan_arr_ts
  [VName]
map_arrs <- [VName] -> [Type] -> m [VName]
forall (m :: * -> *).
Transformer m =>
[VName] -> [Type] -> m [VName]
resultArray [VName]
arrs [Type]
map_arr_ts

  [Param DeclType]
scanacc_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"scanacc" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Nonunique) ([Type] -> m [Param DeclType]) -> [Type] -> m [Param DeclType]
forall a b. (a -> b) -> a -> b
$ Lambda (Rep m) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Rep m)
scan_lam
  [Param DeclType]
scanout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"scanout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Unique) [Type]
scan_arr_ts
  [Param DeclType]
redout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"redout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Nonunique) ([Type] -> m [Param DeclType]) -> [Type] -> m [Param DeclType]
forall a b. (a -> b) -> a -> b
$ Lambda (Rep m) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Rep m)
red_lam
  [Param DeclType]
mapout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"mapout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Unique) [Type]
map_arr_ts

  [Type]
arr_ts <- (VName -> m Type) -> [VName] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> m Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
arrs
  let paramForAcc :: Type -> Maybe (Param DeclType)
paramForAcc (Acc VName
c Shape
_ [Type]
_ NoUniqueness
_) = (Param DeclType -> Bool)
-> [Param DeclType] -> Maybe (Param DeclType)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find (Type -> Bool
f (Type -> Bool)
-> (Param DeclType -> Type) -> Param DeclType -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> Type
forall dec. Typed dec => Param dec -> Type
paramType) [Param DeclType]
mapout_params
        where
          f :: Type -> Bool
f (Acc VName
c2 Shape
_ [Type]
_ NoUniqueness
_) = VName
c VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
c2
          f Type
_ = Bool
False
      paramForAcc Type
_ = Maybe (Param DeclType)
forall a. Maybe a
Nothing

  let merge :: [(Param DeclType, SubExp)]
merge =
        [[(Param DeclType, SubExp)]] -> [(Param DeclType, SubExp)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat
          [ [Param DeclType] -> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
scanacc_params [SubExp]
scan_nes,
            [Param DeclType] -> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
scanout_params ([SubExp] -> [(Param DeclType, SubExp)])
-> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
scan_arrs,
            [Param DeclType] -> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
redout_params [SubExp]
red_nes,
            [Param DeclType] -> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
mapout_params ([SubExp] -> [(Param DeclType, SubExp)])
-> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
map_arrs
          ]
  VName
i <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"i"
  let loopform :: LoopForm (Rep m)
loopform = VName
-> IntType
-> SubExp
-> [(LParam (Rep m), VName)]
-> LoopForm (Rep m)
forall rep.
VName -> IntType -> SubExp -> [(LParam rep, VName)] -> LoopForm rep
ForLoop VName
i IntType
Int64 SubExp
w []
      lam_cons :: Names
lam_cons = Lambda (Aliases (Rep m)) -> Names
forall rep. Aliased rep => Lambda rep -> Names
consumedByLambda (Lambda (Aliases (Rep m)) -> Names)
-> Lambda (Aliases (Rep m)) -> Names
forall a b. (a -> b) -> a -> b
$ AliasTable -> Lambda (Rep m) -> Lambda (Aliases (Rep m))
forall rep.
(ASTRep rep, CanBeAliased (Op rep)) =>
AliasTable -> Lambda rep -> Lambda (Aliases rep)
Alias.analyseLambda AliasTable
forall a. Monoid a => a
mempty Lambda (Rep m)
map_lam

  Body (Rep m)
loop_body <- Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m))
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
 SameScope somerep rep) =>
Builder rep (Body rep) -> m (Body rep)
runBodyBuilder
    (Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m)))
-> (Builder (Rep m) (Body (Rep m))
    -> Builder (Rep m) (Body (Rep m)))
-> Builder (Rep m) (Body (Rep m))
-> m (Body (Rep m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Scope (Rep m)
-> Builder (Rep m) (Body (Rep m)) -> Builder (Rep m) (Body (Rep m))
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param DeclType] -> Scope (Rep m)
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams (((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge) Scope (Rep m) -> Scope (Rep m) -> Scope (Rep m)
forall a. Semigroup a => a -> a -> a
<> LoopForm (Rep m) -> Scope (Rep m)
forall rep a. Scoped rep a => a -> Scope rep
scopeOf LoopForm (Rep m)
loopform)
    (Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m)))
-> Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m))
forall a b. (a -> b) -> a -> b
$ do
      -- Bind the parameters to the lambda.
      [(Param Type, VName, Type)]
-> ((Param Type, VName, Type)
    -> BuilderT (Rep m) (State VNameSource) ())
-> BuilderT (Rep m) (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param Type] -> [VName] -> [Type] -> [(Param Type, VName, Type)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (Lambda (Rep m) -> [LParam (Rep m)]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda (Rep m)
map_lam) [VName]
arrs [Type]
arr_ts) (((Param Type, VName, Type)
  -> BuilderT (Rep m) (State VNameSource) ())
 -> BuilderT (Rep m) (State VNameSource) ())
-> ((Param Type, VName, Type)
    -> BuilderT (Rep m) (State VNameSource) ())
-> BuilderT (Rep m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, VName
arr, Type
arr_t) ->
        case Type -> Maybe (Param DeclType)
paramForAcc Type
arr_t of
          Just Param DeclType
acc_out_p ->
            [VName]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p] (ExpT (Rep m) -> BuilderT (Rep m) (State VNameSource) ())
-> (BasicOp -> ExpT (Rep m))
-> BasicOp
-> BuilderT (Rep m) (State VNameSource) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> ExpT (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> BuilderT (Rep m) (State VNameSource) ())
-> BasicOp -> BuilderT (Rep m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
              SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param DeclType -> VName
forall dec. Param dec -> VName
paramName Param DeclType
acc_out_p
          Maybe (Param DeclType)
Nothing
            | Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p VName -> Names -> Bool
`nameIn` Names
lam_cons -> do
              VName
p' <-
                [Char]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp (VName -> [Char]
baseString (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p)) (Exp (Rep (BuilderT (Rep m) (State VNameSource)))
 -> BuilderT (Rep m) (State VNameSource) VName)
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
                  BasicOp -> ExpT (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT (Rep m)) -> BasicOp -> ExpT (Rep m)
forall a b. (a -> b) -> a -> b
$
                    VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i]
              [VName]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p] (Exp (Rep (BuilderT (Rep m) (State VNameSource)))
 -> BuilderT (Rep m) (State VNameSource) ())
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT (Rep m)) -> BasicOp -> ExpT (Rep m)
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
p'
            | Bool
otherwise ->
              [VName]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p] (Exp (Rep (BuilderT (Rep m) (State VNameSource)))
 -> BuilderT (Rep m) (State VNameSource) ())
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
                BasicOp -> ExpT (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT (Rep m)) -> BasicOp -> ExpT (Rep m)
forall a b. (a -> b) -> a -> b
$
                  VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i]

      -- Insert the statements of the lambda.  We have taken care to
      -- ensure that the parameters are bound at this point.
      (Stm (Rep m) -> BuilderT (Rep m) (State VNameSource) ())
-> Seq (Stm (Rep m)) -> BuilderT (Rep m) (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm (Rep m) -> BuilderT (Rep m) (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Seq (Stm (Rep m)) -> BuilderT (Rep m) (State VNameSource) ())
-> Seq (Stm (Rep m)) -> BuilderT (Rep m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Body (Rep m) -> Seq (Stm (Rep m))
forall rep. BodyT rep -> Stms rep
bodyStms (Body (Rep m) -> Seq (Stm (Rep m)))
-> Body (Rep m) -> Seq (Stm (Rep m))
forall a b. (a -> b) -> a -> b
$ Lambda (Rep m) -> Body (Rep m)
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda (Rep m)
map_lam
      -- Split into scan results, reduce results, and map results.
      let (Result
scan_res, Result
red_res, Result
map_res) =
            Int -> Int -> Result -> (Result, Result, Result)
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes) ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) (Result -> (Result, Result, Result))
-> Result -> (Result, Result, Result)
forall a b. (a -> b) -> a -> b
$
              Body (Rep m) -> Result
forall rep. BodyT rep -> Result
bodyResult (Body (Rep m) -> Result) -> Body (Rep m) -> Result
forall a b. (a -> b) -> a -> b
$ Lambda (Rep m) -> Body (Rep m)
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda (Rep m)
map_lam

      Result
scan_res' <-
        Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
-> [BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
-> BuilderT (Rep m) (State VNameSource) Result
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m Result
eLambda Lambda (Rep m)
Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
scan_lam ([BuilderT
    (Rep m)
    (State VNameSource)
    (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
 -> BuilderT (Rep m) (State VNameSource) Result)
-> [BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
-> BuilderT (Rep m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
          (SubExp -> BuilderT (Rep m) (State VNameSource) (ExpT (Rep m)))
-> [SubExp]
-> [BuilderT (Rep m) (State VNameSource) (ExpT (Rep m))]
forall a b. (a -> b) -> [a] -> [b]
map (ExpT (Rep m) -> BuilderT (Rep m) (State VNameSource) (ExpT (Rep m))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT (Rep m)
 -> BuilderT (Rep m) (State VNameSource) (ExpT (Rep m)))
-> (SubExp -> ExpT (Rep m))
-> SubExp
-> BuilderT (Rep m) (State VNameSource) (ExpT (Rep m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> ExpT (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT (Rep m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Rep m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) ([SubExp] -> [BuilderT (Rep m) (State VNameSource) (ExpT (Rep m))])
-> [SubExp]
-> [BuilderT (Rep m) (State VNameSource) (ExpT (Rep m))]
forall a b. (a -> b) -> a -> b
$
            (Param DeclType -> SubExp) -> [Param DeclType] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (Param DeclType -> VName) -> Param DeclType -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> VName
forall dec. Param dec -> VName
paramName) [Param DeclType]
scanacc_params [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
scan_res
      Result
red_res' <-
        Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
-> [BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
-> BuilderT (Rep m) (State VNameSource) Result
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m Result
eLambda Lambda (Rep m)
Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
red_lam ([BuilderT
    (Rep m)
    (State VNameSource)
    (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
 -> BuilderT (Rep m) (State VNameSource) Result)
-> [BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
-> BuilderT (Rep m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
          (SubExp -> BuilderT (Rep m) (State VNameSource) (ExpT (Rep m)))
-> [SubExp]
-> [BuilderT (Rep m) (State VNameSource) (ExpT (Rep m))]
forall a b. (a -> b) -> [a] -> [b]
map (ExpT (Rep m) -> BuilderT (Rep m) (State VNameSource) (ExpT (Rep m))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT (Rep m)
 -> BuilderT (Rep m) (State VNameSource) (ExpT (Rep m)))
-> (SubExp -> ExpT (Rep m))
-> SubExp
-> BuilderT (Rep m) (State VNameSource) (ExpT (Rep m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> ExpT (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT (Rep m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Rep m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) ([SubExp] -> [BuilderT (Rep m) (State VNameSource) (ExpT (Rep m))])
-> [SubExp]
-> [BuilderT (Rep m) (State VNameSource) (ExpT (Rep m))]
forall a b. (a -> b) -> a -> b
$
            (Param DeclType -> SubExp) -> [Param DeclType] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (Param DeclType -> VName) -> Param DeclType -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> VName
forall dec. Param dec -> VName
paramName) [Param DeclType]
redout_params [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
red_res

      -- Write the scan accumulator to the scan result arrays.
      [VName]
scan_outarrs <-
        Certs
-> BuilderT (Rep m) (State VNameSource) [VName]
-> BuilderT (Rep m) (State VNameSource) [VName]
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying ((SubExpRes -> Certs) -> Result -> Certs
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap SubExpRes -> Certs
resCerts Result
scan_res) (BuilderT (Rep m) (State VNameSource) [VName]
 -> BuilderT (Rep m) (State VNameSource) [VName])
-> BuilderT (Rep m) (State VNameSource) [VName]
-> BuilderT (Rep m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
          [VName]
-> SubExp
-> [SubExp]
-> BuilderT (Rep m) (State VNameSource) [VName]
forall (m :: * -> *).
Transformer m =>
[VName] -> SubExp -> [SubExp] -> m [VName]
letwith ((Param DeclType -> VName) -> [Param DeclType] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param DeclType -> VName
forall dec. Param dec -> VName
paramName [Param DeclType]
scanout_params) (VName -> SubExp
Var VName
i) ([SubExp] -> BuilderT (Rep m) (State VNameSource) [VName])
-> [SubExp] -> BuilderT (Rep m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
scan_res'

      -- Write the map results to the map result arrays.
      [VName]
map_outarrs <-
        Certs
-> BuilderT (Rep m) (State VNameSource) [VName]
-> BuilderT (Rep m) (State VNameSource) [VName]
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying ((SubExpRes -> Certs) -> Result -> Certs
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap SubExpRes -> Certs
resCerts Result
map_res) (BuilderT (Rep m) (State VNameSource) [VName]
 -> BuilderT (Rep m) (State VNameSource) [VName])
-> BuilderT (Rep m) (State VNameSource) [VName]
-> BuilderT (Rep m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
          [VName]
-> SubExp
-> [SubExp]
-> BuilderT (Rep m) (State VNameSource) [VName]
forall (m :: * -> *).
Transformer m =>
[VName] -> SubExp -> [SubExp] -> m [VName]
letwith ((Param DeclType -> VName) -> [Param DeclType] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param DeclType -> VName
forall dec. Param dec -> VName
paramName [Param DeclType]
mapout_params) (VName -> SubExp
Var VName
i) ([SubExp] -> BuilderT (Rep m) (State VNameSource) [VName])
-> [SubExp] -> BuilderT (Rep m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
map_res

      Body (Rep m) -> Builder (Rep m) (Body (Rep m))
forall (m :: * -> *) a. Monad m => a -> m a
return (Body (Rep m) -> Builder (Rep m) (Body (Rep m)))
-> ([Result] -> Body (Rep m))
-> [Result]
-> Builder (Rep m) (Body (Rep m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Seq (Stm (Rep m)) -> Result -> Body (Rep m)
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody Seq (Stm (Rep m))
forall a. Monoid a => a
mempty (Result -> Body (Rep m))
-> ([Result] -> Result) -> [Result] -> Body (Rep m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Result] -> Result
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([Result] -> Builder (Rep m) (Body (Rep m)))
-> [Result] -> Builder (Rep m) (Body (Rep m))
forall a b. (a -> b) -> a -> b
$
        [ Result
scan_res',
          [VName] -> Result
varsRes [VName]
scan_outarrs,
          Result
red_res',
          [VName] -> Result
varsRes [VName]
map_outarrs
        ]

  -- We need to discard the final scan accumulators, as they are not
  -- bound in the original pattern.
  [VName]
names <-
    ([VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ Pat (Rep m) -> [VName]
forall dec. PatT dec -> [VName]
patNames Pat (Rep m)
pat)
      ([VName] -> [VName]) -> m [VName] -> m [VName]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int -> m VName -> m [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM ([Param DeclType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Param DeclType]
scanacc_params) ([Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"discard")
  [VName] -> ExpT (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName]
names (ExpT (Rep m) -> m ()) -> ExpT (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Rep m), SubExp)]
-> LoopForm (Rep m) -> Body (Rep m) -> ExpT (Rep m)
forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> BodyT rep -> ExpT rep
DoLoop [(Param DeclType, SubExp)]
[(FParam (Rep m), SubExp)]
merge LoopForm (Rep m)
loopform Body (Rep m)
loop_body
transformSOAC Pat (Rep m)
pat (Stream SubExp
w [VName]
arrs StreamForm (Rep m)
_ [SubExp]
nes Lambda (Rep m)
lam) = do
  -- Create a loop that repeatedly applies the lambda body to a
  -- chunksize of 1.  Hopefully this will lead to this outer loop
  -- being the only one, as all the innermost one can be simplified
  -- array (as they will have one iteration each).
  let (Param Type
chunk_size_param, [Param Type]
fold_params, [Param Type]
chunk_params) =
        Int -> [Param Type] -> (Param Type, [Param Type], [Param Type])
forall dec.
Int -> [Param dec] -> (Param dec, [Param dec], [Param dec])
partitionChunkedFoldParameters ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Param Type] -> (Param Type, [Param Type], [Param Type]))
-> [Param Type] -> (Param Type, [Param Type], [Param Type])
forall a b. (a -> b) -> a -> b
$ Lambda (Rep m) -> [LParam (Rep m)]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda (Rep m)
lam

  [(Param DeclType, SubExp)]
mapout_merge <- [Type]
-> (Type -> m (Param DeclType, SubExp))
-> m [(Param DeclType, SubExp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda (Rep m) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Rep m)
lam) ((Type -> m (Param DeclType, SubExp))
 -> m [(Param DeclType, SubExp)])
-> (Type -> m (Param DeclType, SubExp))
-> m [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ \Type
t ->
    let t' :: Type
t' = Type
t Type -> SubExp -> Type
forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` SubExp
w
        scratch :: ExpT (Rep m)
scratch = BasicOp -> ExpT (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT (Rep m)) -> BasicOp -> ExpT (Rep m)
forall a b. (a -> b) -> a -> b
$ PrimType -> [SubExp] -> BasicOp
Scratch (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t') (Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t')
     in (,)
          (Param DeclType -> SubExp -> (Param DeclType, SubExp))
-> m (Param DeclType) -> m (SubExp -> (Param DeclType, SubExp))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"stream_mapout" (Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Type
t' Uniqueness
Unique)
          m (SubExp -> (Param DeclType, SubExp))
-> m SubExp -> m (Param DeclType, SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [Char] -> ExpT (Rep m) -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"stream_mapout_scratch" ExpT (Rep m)
scratch

  -- We need to copy the neutral elements because they may be consumed
  -- in the body of the Stream.
  let copyIfArray :: SubExp -> m SubExp
copyIfArray SubExp
se = do
        Type
se_t <- SubExp -> m Type
forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType SubExp
se
        case (Type
se_t, SubExp
se) of
          (Array {}, Var VName
v) -> [Char] -> Exp (Rep m) -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp (VName -> [Char]
baseString VName
v) (Exp (Rep m) -> m SubExp) -> Exp (Rep m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
v
          (Type, SubExp)
_ -> SubExp -> m SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
se
  [SubExp]
nes' <- (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> m SubExp
forall (m :: * -> *). MonadBuilder m => SubExp -> m SubExp
copyIfArray [SubExp]
nes

  let onType :: TypeBase shape NoUniqueness -> TypeBase shape Uniqueness
onType TypeBase shape NoUniqueness
t = TypeBase shape NoUniqueness
t TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
`toDecl` Uniqueness
Unique
      merge :: [(Param DeclType, SubExp)]
merge = [Param DeclType] -> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip ((Param Type -> Param DeclType) -> [Param Type] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map ((Type -> DeclType) -> Param Type -> Param DeclType
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Type -> DeclType
forall shape.
TypeBase shape NoUniqueness -> TypeBase shape Uniqueness
onType) [Param Type]
fold_params) [SubExp]
nes' [(Param DeclType, SubExp)]
-> [(Param DeclType, SubExp)] -> [(Param DeclType, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Param DeclType, SubExp)]
mapout_merge
      merge_params :: [Param DeclType]
merge_params = ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge
      mapout_params :: [Param DeclType]
mapout_params = ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
mapout_merge

  VName
i <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"i"

  let loop_form :: LoopForm (Rep m)
loop_form = VName
-> IntType
-> SubExp
-> [(LParam (Rep m), VName)]
-> LoopForm (Rep m)
forall rep.
VName -> IntType -> SubExp -> [(LParam rep, VName)] -> LoopForm rep
ForLoop VName
i IntType
Int64 SubExp
w []

  [VName] -> ExpT (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
chunk_size_param] (ExpT (Rep m) -> m ()) -> ExpT (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$
    BasicOp -> ExpT (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT (Rep m)) -> BasicOp -> ExpT (Rep m)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1

  Body (Rep m)
loop_body <- Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m))
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
 SameScope somerep rep) =>
Builder rep (Body rep) -> m (Body rep)
runBodyBuilder (Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m)))
-> Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m))
forall a b. (a -> b) -> a -> b
$
    Scope (Rep m)
-> Builder (Rep m) (Body (Rep m)) -> Builder (Rep m) (Body (Rep m))
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (LoopForm (Rep m) -> Scope (Rep m)
forall rep a. Scoped rep a => a -> Scope rep
scopeOf LoopForm (Rep m)
loop_form Scope (Rep m) -> Scope (Rep m) -> Scope (Rep m)
forall a. Semigroup a => a -> a -> a
<> [Param DeclType] -> Scope (Rep m)
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param DeclType]
merge_params) (Builder (Rep m) (Body (Rep m)) -> Builder (Rep m) (Body (Rep m)))
-> Builder (Rep m) (Body (Rep m)) -> Builder (Rep m) (Body (Rep m))
forall a b. (a -> b) -> a -> b
$ do
      let slice :: [DimIndex SubExp]
slice = [SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (VName -> SubExp
Var VName
i) (VName -> SubExp
Var (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
chunk_size_param)) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1)]
      [(Param Type, VName)]
-> ((Param Type, VName) -> BuilderT (Rep m) (State VNameSource) ())
-> BuilderT (Rep m) (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
chunk_params [VName]
arrs) (((Param Type, VName) -> BuilderT (Rep m) (State VNameSource) ())
 -> BuilderT (Rep m) (State VNameSource) ())
-> ((Param Type, VName) -> BuilderT (Rep m) (State VNameSource) ())
-> BuilderT (Rep m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, VName
arr) ->
        [VName]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p] (ExpT (Rep m) -> BuilderT (Rep m) (State VNameSource) ())
-> (Slice SubExp -> ExpT (Rep m))
-> Slice SubExp
-> BuilderT (Rep m) (State VNameSource) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> ExpT (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT (Rep m))
-> (Slice SubExp -> BasicOp) -> Slice SubExp -> ExpT (Rep m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BuilderT (Rep m) (State VNameSource) ())
-> Slice SubExp -> BuilderT (Rep m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
          Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice (Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
p) [DimIndex SubExp]
slice

      (Result
res, Result
mapout_res) <- Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) (Result -> (Result, Result))
-> BuilderT (Rep m) (State VNameSource) Result
-> BuilderT (Rep m) (State VNameSource) (Result, Result)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Body (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Lambda (Rep m) -> Body (Rep m)
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda (Rep m)
lam)

      [SubExp]
mapout_res' <- [(Param DeclType, SubExpRes)]
-> ((Param DeclType, SubExpRes)
    -> BuilderT (Rep m) (State VNameSource) SubExp)
-> BuilderT (Rep m) (State VNameSource) [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Param DeclType] -> Result -> [(Param DeclType, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
mapout_params Result
mapout_res) (((Param DeclType, SubExpRes)
  -> BuilderT (Rep m) (State VNameSource) SubExp)
 -> BuilderT (Rep m) (State VNameSource) [SubExp])
-> ((Param DeclType, SubExpRes)
    -> BuilderT (Rep m) (State VNameSource) SubExp)
-> BuilderT (Rep m) (State VNameSource) [SubExp]
forall a b. (a -> b) -> a -> b
$ \(Param DeclType
p, SubExpRes Certs
cs SubExp
se) ->
        Certs
-> BuilderT (Rep m) (State VNameSource) SubExp
-> BuilderT (Rep m) (State VNameSource) SubExp
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (BuilderT (Rep m) (State VNameSource) SubExp
 -> BuilderT (Rep m) (State VNameSource) SubExp)
-> (BasicOp -> BuilderT (Rep m) (State VNameSource) SubExp)
-> BasicOp
-> BuilderT (Rep m) (State VNameSource) SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Char]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"mapout_res" (ExpT (Rep m) -> BuilderT (Rep m) (State VNameSource) SubExp)
-> (BasicOp -> ExpT (Rep m))
-> BasicOp
-> BuilderT (Rep m) (State VNameSource) SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> ExpT (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> BuilderT (Rep m) (State VNameSource) SubExp)
-> BasicOp -> BuilderT (Rep m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$
          Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
Unsafe (Param DeclType -> VName
forall dec. Param dec -> VName
paramName Param DeclType
p) (Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice (Param DeclType -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param DeclType
p) [DimIndex SubExp]
slice) SubExp
se

      Stms (Rep (BuilderT (Rep m) (State VNameSource)))
-> Result
-> BuilderT
     (Rep m)
     (State VNameSource)
     (Body (Rep (BuilderT (Rep m) (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
Stms (Rep m) -> Result -> m (Body (Rep m))
mkBodyM Stms (Rep (BuilderT (Rep m) (State VNameSource)))
forall a. Monoid a => a
mempty (Result
 -> BuilderT
      (Rep m)
      (State VNameSource)
      (Body (Rep (BuilderT (Rep m) (State VNameSource)))))
-> Result
-> BuilderT
     (Rep m)
     (State VNameSource)
     (Body (Rep (BuilderT (Rep m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ Result
res Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ [SubExp] -> Result
subExpsRes [SubExp]
mapout_res'

  Pat (Rep m) -> ExpT (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat (Rep m)
pat (ExpT (Rep m) -> m ()) -> ExpT (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Rep m), SubExp)]
-> LoopForm (Rep m) -> Body (Rep m) -> ExpT (Rep m)
forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> BodyT rep -> ExpT rep
DoLoop [(Param DeclType, SubExp)]
[(FParam (Rep m), SubExp)]
merge LoopForm (Rep m)
loop_form Body (Rep m)
loop_body
transformSOAC Pat (Rep m)
pat (Scatter SubExp
len Lambda (Rep m)
lam [VName]
ivs [(Shape, Int, VName)]
as) = do
  VName
iter <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"write_iter"

  let ([Shape]
as_ws, [Int]
as_ns, [VName]
as_vs) = [(Shape, Int, VName)] -> ([Shape], [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Shape, Int, VName)]
as
  [Type]
ts <- (VName -> m Type) -> [VName] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> m Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
as_vs
  [Ident]
asOuts <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
[Char] -> Type -> m Ident
newIdent [Char]
"write_out") [Type]
ts

  -- Scatter is in-place, so we use the input array as the output array.
  let merge :: [(Param DeclType, SubExp)]
merge = [Ident] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge [Ident]
asOuts ([SubExp] -> [(Param DeclType, SubExp)])
-> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
as_vs
  Body (Rep m)
loopBody <- Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m))
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
 SameScope somerep rep) =>
Builder rep (Body rep) -> m (Body rep)
runBodyBuilder (Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m)))
-> Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m))
forall a b. (a -> b) -> a -> b
$
    Scope (Rep m)
-> Builder (Rep m) (Body (Rep m)) -> Builder (Rep m) (Body (Rep m))
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (VName -> NameInfo (Rep m) -> Scope (Rep m) -> Scope (Rep m)
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
iter (IntType -> NameInfo (Rep m)
forall rep. IntType -> NameInfo rep
IndexName IntType
Int64) (Scope (Rep m) -> Scope (Rep m)) -> Scope (Rep m) -> Scope (Rep m)
forall a b. (a -> b) -> a -> b
$ [Param DeclType] -> Scope (Rep m)
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams ([Param DeclType] -> Scope (Rep m))
-> [Param DeclType] -> Scope (Rep m)
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge) (Builder (Rep m) (Body (Rep m)) -> Builder (Rep m) (Body (Rep m)))
-> Builder (Rep m) (Body (Rep m)) -> Builder (Rep m) (Body (Rep m))
forall a b. (a -> b) -> a -> b
$ do
      [SubExp]
ivs' <- [VName]
-> (VName -> BuilderT (Rep m) (State VNameSource) SubExp)
-> BuilderT (Rep m) (State VNameSource) [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
ivs ((VName -> BuilderT (Rep m) (State VNameSource) SubExp)
 -> BuilderT (Rep m) (State VNameSource) [SubExp])
-> (VName -> BuilderT (Rep m) (State VNameSource) SubExp)
-> BuilderT (Rep m) (State VNameSource) [SubExp]
forall a b. (a -> b) -> a -> b
$ \VName
iv -> do
        Type
iv_t <- VName -> BuilderT (Rep m) (State VNameSource) Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
iv
        [Char]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"write_iv" (Exp (Rep (BuilderT (Rep m) (State VNameSource)))
 -> BuilderT (Rep m) (State VNameSource) SubExp)
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT (Rep m)) -> BasicOp -> ExpT (Rep m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
iv (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
iv_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
iter]
      Result
ivs'' <- Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
-> [Exp (Rep (BuilderT (Rep m) (State VNameSource)))]
-> BuilderT (Rep m) (State VNameSource) Result
forall (m :: * -> *).
Transformer m =>
Lambda (Rep m) -> [Exp (Rep m)] -> m Result
bindLambda Lambda (Rep m)
Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
lam ((SubExp -> ExpT (Rep m)) -> [SubExp] -> [ExpT (Rep m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT (Rep m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Rep m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) [SubExp]
ivs')

      let indexes :: [(Shape, VName, [(Result, SubExpRes)])]
indexes = [(Shape, Int, VName)]
-> Result -> [(Shape, VName, [(Result, SubExpRes)])]
forall array a.
[(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults ([Shape] -> [Int] -> [VName] -> [(Shape, Int, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Shape]
as_ws [Int]
as_ns ([VName] -> [(Shape, Int, VName)])
-> [VName] -> [(Shape, Int, VName)]
forall a b. (a -> b) -> a -> b
$ (Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
asOuts) Result
ivs''

      [VName]
ress <- [(Shape, VName, [(Result, SubExpRes)])]
-> ((Shape, VName, [(Result, SubExpRes)])
    -> BuilderT (Rep m) (State VNameSource) VName)
-> BuilderT (Rep m) (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(Shape, VName, [(Result, SubExpRes)])]
indexes (((Shape, VName, [(Result, SubExpRes)])
  -> BuilderT (Rep m) (State VNameSource) VName)
 -> BuilderT (Rep m) (State VNameSource) [VName])
-> ((Shape, VName, [(Result, SubExpRes)])
    -> BuilderT (Rep m) (State VNameSource) VName)
-> BuilderT (Rep m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ \(Shape
_, VName
arr, [(Result, SubExpRes)]
indexes') -> do
        Type
arr_t <- VName -> BuilderT (Rep m) (State VNameSource) Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
        let saveInArray :: VName
-> (Result, SubExpRes)
-> BuilderT (Rep m) (State VNameSource) VName
saveInArray VName
arr' (Result
indexCur, SubExpRes Certs
value_cs SubExp
valueCur) =
              Certs
-> BuilderT (Rep m) (State VNameSource) VName
-> BuilderT (Rep m) (State VNameSource) VName
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying ((SubExpRes -> Certs) -> Result -> Certs
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap SubExpRes -> Certs
resCerts Result
indexCur Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> Certs
value_cs) (BuilderT (Rep m) (State VNameSource) VName
 -> BuilderT (Rep m) (State VNameSource) VName)
-> (ExpT (Rep m) -> BuilderT (Rep m) (State VNameSource) VName)
-> ExpT (Rep m)
-> BuilderT (Rep m) (State VNameSource) VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Char]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"write_out" (ExpT (Rep m) -> BuilderT (Rep m) (State VNameSource) VName)
-> ExpT (Rep m) -> BuilderT (Rep m) (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
                BasicOp -> ExpT (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT (Rep m)) -> BasicOp -> ExpT (Rep m)
forall a b. (a -> b) -> a -> b
$ Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
Safe VName
arr' (Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_t ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> DimIndex SubExp) -> Result -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> (SubExpRes -> SubExp) -> SubExpRes -> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) Result
indexCur) SubExp
valueCur

        (VName
 -> (Result, SubExpRes)
 -> BuilderT (Rep m) (State VNameSource) VName)
-> VName
-> [(Result, SubExpRes)]
-> BuilderT (Rep m) (State VNameSource) VName
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM VName
-> (Result, SubExpRes)
-> BuilderT (Rep m) (State VNameSource) VName
saveInArray VName
arr [(Result, SubExpRes)]
indexes'
      Body (Rep m) -> Builder (Rep m) (Body (Rep m))
forall (m :: * -> *) a. Monad m => a -> m a
return (Body (Rep m) -> Builder (Rep m) (Body (Rep m)))
-> Body (Rep m) -> Builder (Rep m) (Body (Rep m))
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Body (Rep m)
forall rep. Buildable rep => [SubExp] -> Body rep
resultBody ((VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
ress)
  Pat (Rep m) -> ExpT (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat (Rep m)
pat (ExpT (Rep m) -> m ()) -> ExpT (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Rep m), SubExp)]
-> LoopForm (Rep m) -> Body (Rep m) -> ExpT (Rep m)
forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> BodyT rep -> ExpT rep
DoLoop [(Param DeclType, SubExp)]
[(FParam (Rep m), SubExp)]
merge (VName
-> IntType
-> SubExp
-> [(LParam (Rep m), VName)]
-> LoopForm (Rep m)
forall rep.
VName -> IntType -> SubExp -> [(LParam rep, VName)] -> LoopForm rep
ForLoop VName
iter IntType
Int64 SubExp
len []) Body (Rep m)
loopBody
-- Hist: a sequential for-loop over the input images.  The loop
-- carries one unique merge parameter per destination array of every
-- operator, initialised with that destination array.  In each
-- iteration the bucket function is applied to the current element of
-- every image.  Its first (length ops) results are the bucket indexes
-- (one per operator); the remaining results are the values, grouped
-- per operator by that operator's arity.  Each index is checked
-- against the operator's first destination array: when out of bounds
-- (or when the operator has no destination arrays) the histograms are
-- passed through unchanged; otherwise the operator combines the
-- current bucket contents with the new values and the result is
-- written back in place.
transformSOAC pat (Hist len ops bucket_fun imgs) = do
  iter <- newVName "iter"

  -- Bind arguments to parameters for the merge-variables.
  hists_ts <- mapM lookupType $ concatMap histDest ops
  hists_out <- mapM (newIdent "dests") hists_ts
  let merge = loopMerge hists_out $ concatMap (map Var . histDest) ops
      -- Number of values produced by each operator; used to split both
      -- the bucket-function results and the merge parameters per
      -- operator.
      op_arities = map (length . lambdaReturnType . histOp) ops
      iter_scope =
        M.insert iter (IndexName Int64) $ scopeOfFParams $ map fst merge

  -- Bind lambda-bodies for operators.
  loopBody <- runBodyBuilder . localScope iter_scope $ do
    -- Bind images to parameters of bucket function.
    imgs' <- forM imgs $ \img -> do
      img_t <- lookupType img
      letSubExp "pixel" $ BasicOp $ Index img $ fullSlice img_t [DimFix $ Var iter]
    imgs'' <- map resSubExp <$> bindLambda bucket_fun (map (BasicOp . SubExp) imgs')

    -- Split out values from bucket function.
    let (inds, rest) = splitAt (length ops) imgs''
        vals = chunks op_arities rest
        hists_out' = chunks op_arities $ map identName hists_out

    hists_out'' <- forM (zip4 hists_out' ops inds vals) $ \(hist, op, idx, val) -> do
      -- Check whether the index is in bounds.  If it is not, we
      -- return the histograms unchanged.
      let oob = case hist of
            [] -> eSubExp $ constant True
            arr : _ -> eOutOfBounds arr [eSubExp idx]
          outside_bounds_branch = buildBody_ $ pure $ varsRes hist
          inside_bounds_branch = buildBody_ $ do
            -- Read values from histogram.
            h_val <- forM hist $ \arr -> do
              arr_t <- lookupType arr
              letSubExp "read_hist" $ BasicOp $ Index arr $ fullSlice arr_t [DimFix idx]

            -- Apply operator to (current contents, new values).
            h_val' <- bindLambda (histOp op) $ map (BasicOp . SubExp) $ h_val ++ val

            -- Write values back to histograms, keeping any
            -- certificates attached to the operator's results.
            hist' <- forM (zip hist h_val') $ \(arr, SubExpRes cs v) -> do
              arr_t <- lookupType arr
              certifying cs . letInPlace "hist_out" arr (fullSlice arr_t [DimFix idx]) $
                BasicOp $ SubExp v

            pure $ varsRes hist'

      -- The "then" branch is taken when the index is out of bounds.
      letTupExp "new_histo" =<< eIf oob outside_bounds_branch inside_bounds_branch

    pure $ resultBody $ map Var $ concat hists_out''

  -- Wrap up the above into a for-loop.
  letBind pat $ DoLoop merge (ForLoop iter Int64 len []) loopBody

-- | Recursively first-order-transform a lambda.  The parameters and
-- return type are kept as they are; only the body is transformed,
-- with the lambda's parameters brought into scope while doing so.
transformLambda ::
  ( MonadFreshNames m,
    Buildable rep,
    BuilderOps rep,
    LocalScope somerep m,
    SameScope somerep rep,
    LetDec rep ~ LetDec SOACS,
    CanBeAliased (Op rep)
  ) =>
  Lambda ->
  m (AST.Lambda rep)
transformLambda (Lambda params body rettype) = do
  body' <-
    runBodyBuilder $
      localScope (scopeOfLParams params) $
        transformBody body
  pure $ Lambda params body' rettype

-- | Pairwise over @ks@ and @vs@, store each value at index @i@ of the
-- corresponding array via an in-place update (bound as @lw_dest@),
-- returning the names of the results.  For a name of type 'Acc' no
-- indexing is done; the value is simply bound under a new name
-- (@lw_acc@).  Excess elements of the longer list are ignored.
letwith :: Transformer m => [VName] -> SubExp -> [SubExp] -> m [VName]
letwith ks i vs = zipWithM update ks vs
  where
    update k v = do
      k_t <- lookupType k
      let e = BasicOp $ SubExp v
      case k_t of
        Acc {} -> letExp "lw_acc" e
        _ -> letInPlace "lw_dest" k (fullSlice k_t [DimFix i]) e

-- | Inline a lambda: bind each argument expression to the
-- corresponding parameter name (in order), then bind the statements
-- of the lambda body and return its result.  Arguments whose
-- parameter has a non-primitive type are wrapped in 'eCopy' before
-- being bound (presumably so in-place updates in the body cannot
-- alias the caller's arrays — NOTE(review): confirm).  Excess
-- parameters or arguments are ignored.
bindLambda ::
  Transformer m =>
  AST.Lambda (Rep m) ->
  [AST.Exp (Rep m)] ->
  m Result
bindLambda (Lambda params body _) args = do
  forM_ (zip params args) $ \(param, arg) -> do
    arg' <-
      if primType $ paramType param
        then pure arg
        else eCopy $ pure arg
    letBindNames [paramName param] arg'
  bodyBind body

-- | Build loop merge parameters from identifiers and their initial
-- values, marking every parameter as 'Unique'.
loopMerge :: [Ident] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge vars = loopMerge' [(var, Unique) | var <- vars]

-- | Build loop merge parameters from identifiers, each paired with the
-- uniqueness to attach to it, and the corresponding initial values.
-- Each parameter takes the identifier's name and its type with the
-- given uniqueness attached.  Excess elements of the longer list are
-- ignored.
loopMerge' :: [(Ident, Uniqueness)] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge' = zipWith mkMerge
  where
    mkMerge (Ident pname ptype, u) val = (Param pname $ toDecl ptype u, val)

-- Note [Translation of Screma]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--
-- Screma is the most general SOAC.  It is translated by constructing
-- a loop that contains several groups of parameters, in this order:
--
-- (0) Scan accumulator, initialised with neutral element.
-- (1) Scan results, initialised with Scratch.
-- (2) Reduce results (also functioning as accumulators),
--     initialised with neutral element.
-- (3) Map results, mostly initialised with Scratch.
--
-- However, category (3) is a little more tricky in the case where one
-- of the results is an Acc.  In that case, the result is not an
-- array, but another Acc.  Any Acc result of a Map must correspond to
-- an Acc that is an input to the map, and the result is initialised
-- to be that input.  This requires a 1:1 relationship between Acc
-- inputs and Acc outputs, which the type checker should enforce.
-- There is no guarantee that the map results appear in any particular
-- order (e.g. accumulator results before non-accumulator results), so
-- we need to do a little sleuthing to establish the relationship.
--
-- Inside the loop, the non-Acc parameters to map_lam become for-in
-- parameters.  Acc parameters refer to the loop parameters for the
-- corresponding Map result instead.
--
-- Intuitively, a Screma(w,
--                       (scan_op, scan_ne),
--                       (red_op, red_ne),
--                       map_fn,
--                       {acc_input, arr_input})
--
-- then becomes
--
-- loop (scan_acc, scan_arr, red_acc, map_acc, map_arr) =
--   for i < w, x in arr_input do
--     let (a,b,map_acc',d) = map_fn(map_acc, x)
--     let scan_acc' = scan_op(scan_acc, a)
--     let scan_arr[i] = scan_acc'
--     let red_acc' = red_op(red_acc, b)
--     let map_arr[i] = d
--     in (scan_acc', scan_arr, red_acc', map_acc', map_arr)