{-# LANGUAGE TypeFamilies #-}

-- | The code generator cannot handle the array combinators (@map@ and
-- friends), so this module was written to transform them into the
-- equivalent do-loops.  The transformation is currently rather naive,
-- and it would be worth investigating when such transformations can be
-- expressed in-place.
module Futhark.Transform.FirstOrderTransform
  ( transformFunDef,
    transformConsts,
    FirstOrderRep,
    Transformer,
    transformStmRecursively,
    transformLambda,
    transformSOAC,
  )
where

import Control.Monad
import Control.Monad.State
import Data.List (find, zip4)
import Data.Map.Strict qualified as M
import Futhark.Analysis.Alias qualified as Alias
import Futhark.IR qualified as AST
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Util (chunks, splitAt3)

-- | Requirements on a representation @rep@ for it to be the output of
-- first-order transformation: it must support building and alias
-- analysis, and its let-bound and lambda-parameter decorations must
-- coincide with those of 'SOACS'.
type FirstOrderRep rep =
  ( Alias.AliasableRep rep,
    Buildable rep,
    BuilderOps rep,
    LParamInfo SOACS ~ LParamInfo rep,
    LetDec SOACS ~ LetDec rep
  )

-- | First-order-transform a single function, with the given scope
-- provided by top-level constants.
--
-- The function's parameters are brought into scope on top of
-- @consts_scope@ before the body is transformed.  The entry point,
-- attributes, name, return type and parameters are passed through
-- unchanged; only the body is rewritten.
transformFunDef ::
  (MonadFreshNames m, FirstOrderRep torep) =>
  Scope torep ->
  FunDef SOACS ->
  m (AST.FunDef torep)
transformFunDef consts_scope (FunDef entry attrs fname rettype params body) = do
  -- Only the transformed body is kept.  Any statements emitted directly
  -- into the outer builder are discarded; 'transformBody' wraps its own
  -- work in 'buildBody_', so none are expected here.
  (body', _) <- modifyNameSource $ runState $ runBuilderT build consts_scope
  pure $ FunDef entry attrs fname rettype params body'
  where
    build = localScope (scopeOfFParams params) $ transformBody body

-- | First-order-transform these top-level constants.
--
-- The statements are transformed in order, starting from an empty
-- scope.  The statements accumulated by the builder are the result.
transformConsts ::
  (MonadFreshNames m, FirstOrderRep torep) =>
  Stms SOACS ->
  m (AST.Stms torep)
transformConsts stms =
  fmap snd $ modifyNameSource $ runState $ runBuilderT build mempty
  where
    build = mapM_ transformStmRecursively stms

-- | Requirements on a monad @m@ used to carry out first-order
-- transformation: @m@ must be a builder that supports local scopes for
-- its representation.  That representation must support building and
-- alias analysis, and must share its lambda-parameter decorations with
-- 'SOACS'.
type Transformer m =
  ( MonadBuilder m,
    LocalScope (Rep m) m,
    Alias.AliasableRep (Rep m),
    Buildable (Rep m),
    BuilderOps (Rep m),
    LParamInfo SOACS ~ LParamInfo (Rep m)
  )

-- | First-order-transform every statement of a body.  The resulting body
-- is assembled with 'buildBody_' from the transformed statements and
-- has the same result as the original.
transformBody ::
  (Transformer m, LetDec (Rep m) ~ LetDec SOACS) =>
  Body SOACS ->
  m (AST.Body (Rep m))
transformBody (Body () stms res) = buildBody_ $ do
  mapM_ transformStmRecursively stms
  pure res

-- | First transform any nested t'Body' or t'Lambda' elements, then
-- apply 'transformSOAC' if the expression is a SOAC.
--
-- The statement's auxiliary information is reattached, via 'auxing', to
-- whatever is emitted in its place.
transformStmRecursively ::
  (Transformer m, LetDec (Rep m) ~ LetDec SOACS) => Stm SOACS -> m ()
transformStmRecursively (Let pat aux (Op soac)) =
  -- Transform the SOAC's lambdas first, then lower the SOAC itself.
  auxing aux $ transformSOAC pat =<< mapSOACM soacTransform soac
  where
    soacTransform = identitySOACMapper {mapOnSOACLambda = transformLambda}
transformStmRecursively (Let pat aux e) =
  -- Non-SOAC expressions are rebound as-is, after recursing into any
  -- nested bodies (with their scope made available).
  auxing aux $ letBind pat =<< mapExpM transform e
  where
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody,
          mapOnRetType = pure,
          mapOnBranchType = pure,
          mapOnFParam = pure,
          mapOnLParam = pure,
          -- Unreachable: 'Op' expressions are handled by the equation above.
          mapOnOp = error "Unhandled Op in first order transform"
        }
        }

-- Produce scratch "arrays" for the Map and Scan outputs of Screma.
-- "Arrays" is in quotes because some of those may be accumulators.
--
-- For each requested type in @ts@:
--
-- * if it is an accumulator type equal to the type of one of @arrs@,
--   that input is reused directly;
--
-- * otherwise a fresh blank value of that type is bound under the
--   name "result".
resultArray :: (Transformer m) => [VName] -> [Type] -> m [VName]
resultArray arrs ts = do
  arrs_ts <- mapM lookupType arrs
  let oneArray t@Acc {}
        | Just (v, _) <- find ((== t) . snd) (zip arrs arrs_ts) =
            pure v
      -- Also reached by accumulator types with no matching input.
      oneArray t =
        letExp "result" =<< eBlank t
  mapM oneArray ts

-- | Transform a single 'SOAC' into a do-loop.  The body of the lambda
-- is untouched, and may or may not contain further 'SOAC's depending
-- on the given rep.
transformSOAC ::
  (Transformer m) =>
  Pat (LetDec (Rep m)) ->
  SOAC (Rep m) ->
  m ()
transformSOAC :: forall (m :: * -> *).
Transformer m =>
Pat (LetDec (Rep m)) -> SOAC (Rep m) -> m ()
transformSOAC Pat (LetDec (Rep m))
_ JVP {} =
  [Char] -> m ()
forall a. HasCallStack => [Char] -> a
error [Char]
"transformSOAC: unhandled JVP"
transformSOAC Pat (LetDec (Rep m))
_ VJP {} =
  [Char] -> m ()
forall a. HasCallStack => [Char] -> a
error [Char]
"transformSOAC: unhandled VJP"
transformSOAC Pat (LetDec (Rep m))
pat (Screma SubExp
w [VName]
arrs form :: ScremaForm (Rep m)
form@(ScremaForm [Scan (Rep m)]
scans [Reduce (Rep m)]
reds Lambda (Rep m)
map_lam)) = do
  -- See Note [Translation of Screma].
  --
  -- Start by combining all the reduction and scan parts into a single
  -- operator
  let Reduce Commutativity
_ Lambda (Rep m)
red_lam [SubExp]
red_nes = [Reduce (Rep m)] -> Reduce (Rep m)
forall rep. Buildable rep => [Reduce rep] -> Reduce rep
singleReduce [Reduce (Rep m)]
reds
      Scan Lambda (Rep m)
scan_lam [SubExp]
scan_nes = [Scan (Rep m)] -> Scan (Rep m)
forall rep. Buildable rep => [Scan rep] -> Scan rep
singleScan [Scan (Rep m)]
scans
      ([Type]
scan_arr_ts, [Type]
_red_ts, [Type]
map_arr_ts) =
        Int -> Int -> [Type] -> ([Type], [Type], [Type])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes) ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) ([Type] -> ([Type], [Type], [Type]))
-> [Type] -> ([Type], [Type], [Type])
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm (Rep m) -> [Type]
forall rep. SubExp -> ScremaForm rep -> [Type]
scremaType SubExp
w ScremaForm (Rep m)
form

  [VName]
scan_arrs <- [VName] -> [Type] -> m [VName]
forall (m :: * -> *).
Transformer m =>
[VName] -> [Type] -> m [VName]
resultArray [] [Type]
scan_arr_ts
  [VName]
map_arrs <- [VName] -> [Type] -> m [VName]
forall (m :: * -> *).
Transformer m =>
[VName] -> [Type] -> m [VName]
resultArray [VName]
arrs [Type]
map_arr_ts

  [Param DeclType]
scanacc_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"scanacc" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Nonunique) ([Type] -> m [Param DeclType]) -> [Type] -> m [Param DeclType]
forall a b. (a -> b) -> a -> b
$ Lambda (Rep m) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Rep m)
scan_lam
  [Param DeclType]
scanout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"scanout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Unique) [Type]
scan_arr_ts
  [Param DeclType]
redout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"redout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Nonunique) ([Type] -> m [Param DeclType]) -> [Type] -> m [Param DeclType]
forall a b. (a -> b) -> a -> b
$ Lambda (Rep m) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Rep m)
red_lam
  [Param DeclType]
mapout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"mapout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Unique) [Type]
map_arr_ts

  [Type]
arr_ts <- (VName -> m Type) -> [VName] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> m Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
arrs
  let paramForAcc :: Type -> Maybe (Param DeclType)
paramForAcc (Acc VName
c ShapeBase SubExp
_ [Type]
_ NoUniqueness
_) = (Param DeclType -> Bool)
-> [Param DeclType] -> Maybe (Param DeclType)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find (Type -> Bool
f (Type -> Bool)
-> (Param DeclType -> Type) -> Param DeclType -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> Type
forall dec. Typed dec => Param dec -> Type
paramType) [Param DeclType]
mapout_params
        where
          f :: Type -> Bool
f (Acc VName
c2 ShapeBase SubExp
_ [Type]
_ NoUniqueness
_) = VName
c VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
c2
          f Type
_ = Bool
False
      paramForAcc Type
_ = Maybe (Param DeclType)
forall a. Maybe a
Nothing

  let merge :: [(Param DeclType, SubExp)]
merge =
        [[(Param DeclType, SubExp)]] -> [(Param DeclType, SubExp)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat
          [ [Param DeclType] -> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
scanacc_params [SubExp]
scan_nes,
            [Param DeclType] -> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
scanout_params ([SubExp] -> [(Param DeclType, SubExp)])
-> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
scan_arrs,
            [Param DeclType] -> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
redout_params [SubExp]
red_nes,
            [Param DeclType] -> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
mapout_params ([SubExp] -> [(Param DeclType, SubExp)])
-> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
map_arrs
          ]
  VName
i <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"i"
  let loopform :: LoopForm
loopform = VName -> IntType -> SubExp -> LoopForm
ForLoop VName
i IntType
Int64 SubExp
w
      lam_cons :: Names
lam_cons = Lambda (Aliases (Rep m)) -> Names
forall rep. Aliased rep => Lambda rep -> Names
consumedByLambda (Lambda (Aliases (Rep m)) -> Names)
-> Lambda (Aliases (Rep m)) -> Names
forall a b. (a -> b) -> a -> b
$ AliasTable -> Lambda (Rep m) -> Lambda (Aliases (Rep m))
forall rep.
AliasableRep rep =>
AliasTable -> Lambda rep -> Lambda (Aliases rep)
Alias.analyseLambda AliasTable
forall a. Monoid a => a
mempty Lambda (Rep m)
map_lam

  Body (Rep m)
loop_body <- Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m))
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
 SameScope somerep rep) =>
Builder rep (Body rep) -> m (Body rep)
runBodyBuilder
    (Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m)))
-> (Builder (Rep m) (Body (Rep m))
    -> Builder (Rep m) (Body (Rep m)))
-> Builder (Rep m) (Body (Rep m))
-> m (Body (Rep m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Scope (Rep m)
-> Builder (Rep m) (Body (Rep m)) -> Builder (Rep m) (Body (Rep m))
forall a.
Scope (Rep m)
-> BuilderT (Rep m) (State VNameSource) a
-> BuilderT (Rep m) (State VNameSource) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param DeclType] -> Scope (Rep m)
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams (((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge) Scope (Rep m) -> Scope (Rep m) -> Scope (Rep m)
forall a. Semigroup a => a -> a -> a
<> LoopForm -> Scope (Rep m)
forall rep. LoopForm -> Scope rep
scopeOfLoopForm LoopForm
loopform)
    (Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m)))
-> Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m))
forall a b. (a -> b) -> a -> b
$ do
      -- Bind the parameters to the lambda.
      [(Param (LParamInfo (Rep m)), VName, Type)]
-> ((Param (LParamInfo (Rep m)), VName, Type)
    -> BuilderT (Rep m) (State VNameSource) ())
-> BuilderT (Rep m) (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (LParamInfo (Rep m))]
-> [VName] -> [Type] -> [(Param (LParamInfo (Rep m)), VName, Type)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (Lambda (Rep m) -> [Param (LParamInfo (Rep m))]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda (Rep m)
map_lam) [VName]
arrs [Type]
arr_ts) (((Param (LParamInfo (Rep m)), VName, Type)
  -> BuilderT (Rep m) (State VNameSource) ())
 -> BuilderT (Rep m) (State VNameSource) ())
-> ((Param (LParamInfo (Rep m)), VName, Type)
    -> BuilderT (Rep m) (State VNameSource) ())
-> BuilderT (Rep m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param (LParamInfo (Rep m))
p, VName
arr, Type
arr_t) ->
        case Type -> Maybe (Param DeclType)
paramForAcc Type
arr_t of
          Just Param DeclType
acc_out_p ->
            [VName]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param (LParamInfo (Rep m)) -> VName
forall dec. Param dec -> VName
paramName Param (LParamInfo (Rep m))
p] (Exp (Rep m) -> BuilderT (Rep m) (State VNameSource) ())
-> (BasicOp -> Exp (Rep m))
-> BasicOp
-> BuilderT (Rep m) (State VNameSource) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> BuilderT (Rep m) (State VNameSource) ())
-> BasicOp -> BuilderT (Rep m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
              SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
                VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$
                  Param DeclType -> VName
forall dec. Param dec -> VName
paramName Param DeclType
acc_out_p
          Maybe (Param DeclType)
Nothing
            | Param (LParamInfo (Rep m)) -> VName
forall dec. Param dec -> VName
paramName Param (LParamInfo (Rep m))
p VName -> Names -> Bool
`nameIn` Names
lam_cons -> do
                VName
p' <-
                  [Char]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp (VName -> [Char]
baseString (Param (LParamInfo (Rep m)) -> VName
forall dec. Param dec -> VName
paramName Param (LParamInfo (Rep m))
p)) (Exp (Rep m) -> BuilderT (Rep m) (State VNameSource) VName)
-> (BasicOp -> Exp (Rep m))
-> BasicOp
-> BuilderT (Rep m) (State VNameSource) VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> BuilderT (Rep m) (State VNameSource) VName)
-> BasicOp -> BuilderT (Rep m) (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
                    VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
                      Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i]
                [VName]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param (LParamInfo (Rep m)) -> VName
forall dec. Param dec -> VName
paramName Param (LParamInfo (Rep m))
p] (Exp (Rep (BuilderT (Rep m) (State VNameSource)))
 -> BuilderT (Rep m) (State VNameSource) ())
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT (Rep m) (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate ShapeBase SubExp
forall a. Monoid a => a
mempty (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
p'
            | Bool
otherwise ->
                [VName]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param (LParamInfo (Rep m)) -> VName
forall dec. Param dec -> VName
paramName Param (LParamInfo (Rep m))
p] (Exp (Rep m) -> BuilderT (Rep m) (State VNameSource) ())
-> (Slice SubExp -> Exp (Rep m))
-> Slice SubExp
-> BuilderT (Rep m) (State VNameSource) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m))
-> (Slice SubExp -> BasicOp) -> Slice SubExp -> Exp (Rep m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BuilderT (Rep m) (State VNameSource) ())
-> Slice SubExp -> BuilderT (Rep m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
                  Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i]

      -- Insert the statements of the lambda.  We have taken care to
      -- ensure that the parameters are bound at this point.
      (Stm (Rep m) -> BuilderT (Rep m) (State VNameSource) ())
-> Stms (Rep m) -> BuilderT (Rep m) (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm (Rep m) -> BuilderT (Rep m) (State VNameSource) ()
Stm (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stms (Rep m) -> BuilderT (Rep m) (State VNameSource) ())
-> Stms (Rep m) -> BuilderT (Rep m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Body (Rep m) -> Stms (Rep m)
forall rep. Body rep -> Stms rep
bodyStms (Body (Rep m) -> Stms (Rep m)) -> Body (Rep m) -> Stms (Rep m)
forall a b. (a -> b) -> a -> b
$ Lambda (Rep m) -> Body (Rep m)
forall rep. Lambda rep -> Body rep
lambdaBody Lambda (Rep m)
map_lam
      -- Split into scan results, reduce results, and map results.
      let (Result
scan_res, Result
red_res, Result
map_res) =
            Int -> Int -> Result -> (Result, Result, Result)
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes) ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) (Result -> (Result, Result, Result))
-> Result -> (Result, Result, Result)
forall a b. (a -> b) -> a -> b
$
              Body (Rep m) -> Result
forall rep. Body rep -> Result
bodyResult (Body (Rep m) -> Result) -> Body (Rep m) -> Result
forall a b. (a -> b) -> a -> b
$
                Lambda (Rep m) -> Body (Rep m)
forall rep. Lambda rep -> Body rep
lambdaBody Lambda (Rep m)
map_lam

      Result
scan_res' <-
        Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
-> [BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
-> BuilderT (Rep m) (State VNameSource) Result
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m Result
eLambda Lambda (Rep m)
Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
scan_lam ([BuilderT
    (Rep m)
    (State VNameSource)
    (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
 -> BuilderT (Rep m) (State VNameSource) Result)
-> [BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
-> BuilderT (Rep m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
          (SubExp
 -> BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource)))))
-> [SubExp]
-> [BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
forall a b. (a -> b) -> [a] -> [b]
map (Exp (Rep m) -> BuilderT (Rep m) (State VNameSource) (Exp (Rep m))
forall a. a -> BuilderT (Rep m) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp (Rep m) -> BuilderT (Rep m) (State VNameSource) (Exp (Rep m)))
-> (SubExp -> Exp (Rep m))
-> SubExp
-> BuilderT (Rep m) (State VNameSource) (Exp (Rep m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m))
-> (SubExp -> BasicOp) -> SubExp -> Exp (Rep m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) ([SubExp]
 -> [BuilderT
       (Rep m)
       (State VNameSource)
       (Exp (Rep (BuilderT (Rep m) (State VNameSource))))])
-> [SubExp]
-> [BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
forall a b. (a -> b) -> a -> b
$
            (Param DeclType -> SubExp) -> [Param DeclType] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (Param DeclType -> VName) -> Param DeclType -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> VName
forall dec. Param dec -> VName
paramName) [Param DeclType]
scanacc_params [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
scan_res
      Result
red_res' <-
        Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
-> [BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
-> BuilderT (Rep m) (State VNameSource) Result
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m Result
eLambda Lambda (Rep m)
Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
red_lam ([BuilderT
    (Rep m)
    (State VNameSource)
    (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
 -> BuilderT (Rep m) (State VNameSource) Result)
-> [BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
-> BuilderT (Rep m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
          (SubExp
 -> BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource)))))
-> [SubExp]
-> [BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
forall a b. (a -> b) -> [a] -> [b]
map (Exp (Rep m) -> BuilderT (Rep m) (State VNameSource) (Exp (Rep m))
forall a. a -> BuilderT (Rep m) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp (Rep m) -> BuilderT (Rep m) (State VNameSource) (Exp (Rep m)))
-> (SubExp -> Exp (Rep m))
-> SubExp
-> BuilderT (Rep m) (State VNameSource) (Exp (Rep m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m))
-> (SubExp -> BasicOp) -> SubExp -> Exp (Rep m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) ([SubExp]
 -> [BuilderT
       (Rep m)
       (State VNameSource)
       (Exp (Rep (BuilderT (Rep m) (State VNameSource))))])
-> [SubExp]
-> [BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
forall a b. (a -> b) -> a -> b
$
            (Param DeclType -> SubExp) -> [Param DeclType] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (Param DeclType -> VName) -> Param DeclType -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> VName
forall dec. Param dec -> VName
paramName) [Param DeclType]
redout_params [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
red_res

      -- Write the scan accumulator to the scan result arrays.
      [VName]
scan_outarrs <-
        Certs
-> BuilderT (Rep m) (State VNameSource) [VName]
-> BuilderT (Rep m) (State VNameSource) [VName]
forall a.
Certs
-> BuilderT (Rep m) (State VNameSource) a
-> BuilderT (Rep m) (State VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying ((SubExpRes -> Certs) -> Result -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap SubExpRes -> Certs
resCerts Result
scan_res) (BuilderT (Rep m) (State VNameSource) [VName]
 -> BuilderT (Rep m) (State VNameSource) [VName])
-> BuilderT (Rep m) (State VNameSource) [VName]
-> BuilderT (Rep m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
          [VName]
-> SubExp
-> [SubExp]
-> BuilderT (Rep m) (State VNameSource) [VName]
forall (m :: * -> *).
Transformer m =>
[VName] -> SubExp -> [SubExp] -> m [VName]
letwith ((Param DeclType -> VName) -> [Param DeclType] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param DeclType -> VName
forall dec. Param dec -> VName
paramName [Param DeclType]
scanout_params) (VName -> SubExp
Var VName
i) ([SubExp] -> BuilderT (Rep m) (State VNameSource) [VName])
-> [SubExp] -> BuilderT (Rep m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
            (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
scan_res'

      -- Write the map results to the map result arrays.
      [VName]
map_outarrs <-
        Certs
-> BuilderT (Rep m) (State VNameSource) [VName]
-> BuilderT (Rep m) (State VNameSource) [VName]
forall a.
Certs
-> BuilderT (Rep m) (State VNameSource) a
-> BuilderT (Rep m) (State VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying ((SubExpRes -> Certs) -> Result -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap SubExpRes -> Certs
resCerts Result
map_res) (BuilderT (Rep m) (State VNameSource) [VName]
 -> BuilderT (Rep m) (State VNameSource) [VName])
-> BuilderT (Rep m) (State VNameSource) [VName]
-> BuilderT (Rep m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
          [VName]
-> SubExp
-> [SubExp]
-> BuilderT (Rep m) (State VNameSource) [VName]
forall (m :: * -> *).
Transformer m =>
[VName] -> SubExp -> [SubExp] -> m [VName]
letwith ((Param DeclType -> VName) -> [Param DeclType] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param DeclType -> VName
forall dec. Param dec -> VName
paramName [Param DeclType]
mapout_params) (VName -> SubExp
Var VName
i) ([SubExp] -> BuilderT (Rep m) (State VNameSource) [VName])
-> [SubExp] -> BuilderT (Rep m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
            (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
map_res

      Body (Rep m) -> Builder (Rep m) (Body (Rep m))
forall a. a -> BuilderT (Rep m) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Body (Rep m) -> Builder (Rep m) (Body (Rep m)))
-> ([Result] -> Body (Rep m))
-> [Result]
-> Builder (Rep m) (Body (Rep m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms (Rep m) -> Result -> Body (Rep m)
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody Stms (Rep m)
forall a. Monoid a => a
mempty (Result -> Body (Rep m))
-> ([Result] -> Result) -> [Result] -> Body (Rep m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Result] -> Result
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([Result] -> Builder (Rep m) (Body (Rep m)))
-> [Result] -> Builder (Rep m) (Body (Rep m))
forall a b. (a -> b) -> a -> b
$
        [ Result
scan_res',
          [VName] -> Result
varsRes [VName]
scan_outarrs,
          Result
red_res',
          [VName] -> Result
varsRes [VName]
map_outarrs
        ]

  -- We need to discard the final scan accumulators, as they are not
  -- bound in the original pattern.
  [VName]
names <-
    ([VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ Pat (LetDec (Rep m)) -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (LetDec (Rep m))
pat)
      ([VName] -> [VName]) -> m [VName] -> m [VName]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int -> m VName -> m [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM ([Param DeclType] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Param DeclType]
scanacc_params) ([Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"discard")
  [VName] -> Exp (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName]
names (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Rep m), SubExp)]
-> LoopForm -> Body (Rep m) -> Exp (Rep m)
forall rep.
[(FParam rep, SubExp)] -> LoopForm -> Body rep -> Exp rep
Loop [(Param DeclType, SubExp)]
[(FParam (Rep m), SubExp)]
merge LoopForm
loopform Body (Rep m)
loop_body
transformSOAC pat (Stream w arrs nes lam) = do
  -- Lower the stream to a sequential do-loop over @w@ iterations, where each
  -- iteration feeds the lambda a chunk of exactly one element.  The intent is
  -- that this becomes the only real loop: any loop inside the lambda that
  -- ranges over the size-one chunk runs a single iteration and can later be
  -- simplified away.
  let num_accs = length nes
      (chunk_size_param, acc_params, elem_params) =
        partitionChunkedFoldParameters num_accs (lambdaParams lam)
      -- Everything the lambda returns after the accumulators is mapped out.
      mapout_ts = drop num_accs (lambdaReturnType lam)

  -- One unique loop parameter per mapped-out result.  Each is initialised
  -- with a scratch array whose outer dimension is set to the stream width.
  mapout_merge <- forM mapout_ts $ \t -> do
    let full_t = t `setOuterSize` w
    param <- newParam "stream_mapout" $ toDecl full_t Unique
    initial <-
      letSubExp "stream_mapout_scratch" . BasicOp $
        Scratch (elemType full_t) (arrayDims full_t)
    pure (param, initial)

  -- The neutral elements may be consumed inside the loop body, so array
  -- variables are copied (a Replicate with an empty shape).  Constants and
  -- non-array values are passed through untouched.  The same helper is
  -- reused below on the accumulator values produced by each iteration.
  let copyIfArray se = do
        se_t <- subExpType se
        case (se_t, se) of
          (Array {}, Var v) ->
            letSubExp (baseString v) . BasicOp $ Replicate mempty se
          _ -> pure se
  nes_copied <- mapM copyIfArray nes

  -- Accumulators come first in the loop's merge list, followed by the
  -- mapped-out result arrays; all merge parameters are marked Unique.
  let acc_merge = zip (map (fmap (`toDecl` Unique)) acc_params) nes_copied
      merge = acc_merge ++ mapout_merge
      mapout_params = map fst mapout_merge

  i <- newVName "i"
  let loop_form = ForLoop i Int64 w
      one = intConst Int64 1

  -- The chunk size is the constant 1 for every iteration.
  letBindNames [paramName chunk_size_param] $ BasicOp $ SubExp one

  loop_body <-
    runBodyBuilder
      . localScope (scopeOfLoopForm loop_form <> scopeOfFParams (map fst merge))
      $ do
        -- The size-one chunk starting at index @i@: slice [i : chunk_size : 1].
        let slice = [DimSlice (Var i) (Var (paramName chunk_size_param)) one]

        -- Bind each chunk parameter to its slice of the matching input array.
        forM_ (zip elem_params arrs) $ \(p, arr) ->
          letBindNames [paramName p] $
            BasicOp $
              Index arr $
                fullSlice (paramType p) slice

        body_res <- bodyBind (lambdaBody lam)
        let (acc_res, mapout_res) = splitAt num_accs body_res

        -- Accumulator certificates are not carried over here; only the
        -- (possibly copied) values are returned.
        acc_res' <- mapM (copyIfArray . resSubExp) acc_res

        -- Store each mapped-out chunk into its result array at position @i@,
        -- under the certificates attached to that result.
        mapout_res' <-
          zipWithM
            ( \p (SubExpRes cs se) ->
                certifying cs $
                  letSubExp "mapout_res" $
                    BasicOp $
                      Update Unsafe (paramName p) (fullSlice (paramType p) slice) se
            )
            mapout_params
            mapout_res

        mkBodyM mempty $ subExpsRes $ acc_res' ++ mapout_res'

  letBind pat $ Loop merge loop_form loop_body
transformSOAC Pat (LetDec (Rep m))
pat (Scatter SubExp
len [VName]
ivs Lambda (Rep m)
lam [(ShapeBase SubExp, Int, VName)]
as) = do
  VName
iter <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"write_iter"

  let ([ShapeBase SubExp]
as_ws, [Int]
as_ns, [VName]
as_vs) = [(ShapeBase SubExp, Int, VName)]
-> ([ShapeBase SubExp], [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(ShapeBase SubExp, Int, VName)]
as
  [Type]
ts <- (VName -> m Type) -> [VName] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> m Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
as_vs
  [Ident]
asOuts <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ([Char] -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
[Char] -> Type -> m Ident
newIdent [Char]
"write_out") [Type]
ts

  -- Scatter is in-place, so we use the input array as the output array.
  let merge :: [(Param DeclType, SubExp)]
merge = [Ident] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge [Ident]
asOuts ([SubExp] -> [(Param DeclType, SubExp)])
-> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
as_vs
  Body (Rep m)
loopBody <- Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m))
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
 SameScope somerep rep) =>
Builder rep (Body rep) -> m (Body rep)
runBodyBuilder (Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m)))
-> Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m))
forall a b. (a -> b) -> a -> b
$
    Scope (Rep m)
-> Builder (Rep m) (Body (Rep m)) -> Builder (Rep m) (Body (Rep m))
forall a.
Scope (Rep m)
-> BuilderT (Rep m) (State VNameSource) a
-> BuilderT (Rep m) (State VNameSource) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (VName -> NameInfo (Rep m) -> Scope (Rep m) -> Scope (Rep m)
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
iter (IntType -> NameInfo (Rep m)
forall rep. IntType -> NameInfo rep
IndexName IntType
Int64) (Scope (Rep m) -> Scope (Rep m)) -> Scope (Rep m) -> Scope (Rep m)
forall a b. (a -> b) -> a -> b
$ [Param DeclType] -> Scope (Rep m)
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams ([Param DeclType] -> Scope (Rep m))
-> [Param DeclType] -> Scope (Rep m)
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge) (Builder (Rep m) (Body (Rep m)) -> Builder (Rep m) (Body (Rep m)))
-> Builder (Rep m) (Body (Rep m)) -> Builder (Rep m) (Body (Rep m))
forall a b. (a -> b) -> a -> b
$ do
      [SubExp]
ivs' <- [VName]
-> (VName -> BuilderT (Rep m) (State VNameSource) SubExp)
-> BuilderT (Rep m) (State VNameSource) [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
ivs ((VName -> BuilderT (Rep m) (State VNameSource) SubExp)
 -> BuilderT (Rep m) (State VNameSource) [SubExp])
-> (VName -> BuilderT (Rep m) (State VNameSource) SubExp)
-> BuilderT (Rep m) (State VNameSource) [SubExp]
forall a b. (a -> b) -> a -> b
$ \VName
iv -> do
        Type
iv_t <- VName -> BuilderT (Rep m) (State VNameSource) Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
iv
        [Char]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"write_iv" (Exp (Rep (BuilderT (Rep m) (State VNameSource)))
 -> BuilderT (Rep m) (State VNameSource) SubExp)
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT (Rep m) (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
iv (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
iv_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
iter]
      Result
ivs'' <- Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
-> [Exp (Rep (BuilderT (Rep m) (State VNameSource)))]
-> BuilderT (Rep m) (State VNameSource) Result
forall (m :: * -> *).
Transformer m =>
Lambda (Rep m) -> [Exp (Rep m)] -> m Result
bindLambda Lambda (Rep m)
Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
lam ((SubExp -> Exp (Rep m)) -> [SubExp] -> [Exp (Rep m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m))
-> (SubExp -> BasicOp) -> SubExp -> Exp (Rep m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) [SubExp]
ivs')

      let indexes :: [(ShapeBase SubExp, VName, [(Result, SubExpRes)])]
indexes = [(ShapeBase SubExp, Int, VName)]
-> Result -> [(ShapeBase SubExp, VName, [(Result, SubExpRes)])]
forall array a.
[(ShapeBase SubExp, Int, array)]
-> [a] -> [(ShapeBase SubExp, array, [([a], a)])]
groupScatterResults ([ShapeBase SubExp]
-> [Int] -> [VName] -> [(ShapeBase SubExp, Int, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [ShapeBase SubExp]
as_ws [Int]
as_ns ([VName] -> [(ShapeBase SubExp, Int, VName)])
-> [VName] -> [(ShapeBase SubExp, Int, VName)]
forall a b. (a -> b) -> a -> b
$ (Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
asOuts) Result
ivs''

      [VName]
ress <- [(ShapeBase SubExp, VName, [(Result, SubExpRes)])]
-> ((ShapeBase SubExp, VName, [(Result, SubExpRes)])
    -> BuilderT (Rep m) (State VNameSource) VName)
-> BuilderT (Rep m) (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(ShapeBase SubExp, VName, [(Result, SubExpRes)])]
indexes (((ShapeBase SubExp, VName, [(Result, SubExpRes)])
  -> BuilderT (Rep m) (State VNameSource) VName)
 -> BuilderT (Rep m) (State VNameSource) [VName])
-> ((ShapeBase SubExp, VName, [(Result, SubExpRes)])
    -> BuilderT (Rep m) (State VNameSource) VName)
-> BuilderT (Rep m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ \(ShapeBase SubExp
_, VName
arr, [(Result, SubExpRes)]
indexes') -> do
        Type
arr_t <- VName -> BuilderT (Rep m) (State VNameSource) Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
        let saveInArray :: VName
-> (Result, SubExpRes)
-> BuilderT (Rep m) (State VNameSource) VName
saveInArray VName
arr' (Result
indexCur, SubExpRes Certs
value_cs SubExp
valueCur) =
              Certs
-> BuilderT (Rep m) (State VNameSource) VName
-> BuilderT (Rep m) (State VNameSource) VName
forall a.
Certs
-> BuilderT (Rep m) (State VNameSource) a
-> BuilderT (Rep m) (State VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying ((SubExpRes -> Certs) -> Result -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap SubExpRes -> Certs
resCerts Result
indexCur Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> Certs
value_cs) (BuilderT (Rep m) (State VNameSource) VName
 -> BuilderT (Rep m) (State VNameSource) VName)
-> (Exp (Rep m) -> BuilderT (Rep m) (State VNameSource) VName)
-> Exp (Rep m)
-> BuilderT (Rep m) (State VNameSource) VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Char]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"write_out" (Exp (Rep m) -> BuilderT (Rep m) (State VNameSource) VName)
-> Exp (Rep m) -> BuilderT (Rep m) (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
                BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$
                  Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
Safe VName
arr' (Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_t ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> DimIndex SubExp) -> Result -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> (SubExpRes -> SubExp) -> SubExpRes -> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) Result
indexCur) SubExp
valueCur

        (VName
 -> (Result, SubExpRes)
 -> BuilderT (Rep m) (State VNameSource) VName)
-> VName
-> [(Result, SubExpRes)]
-> BuilderT (Rep m) (State VNameSource) VName
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM VName
-> (Result, SubExpRes)
-> BuilderT (Rep m) (State VNameSource) VName
saveInArray VName
arr [(Result, SubExpRes)]
indexes'
      Body (Rep m) -> Builder (Rep m) (Body (Rep m))
forall a. a -> BuilderT (Rep m) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Body (Rep m) -> Builder (Rep m) (Body (Rep m)))
-> Body (Rep m) -> Builder (Rep m) (Body (Rep m))
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Body (Rep m)
forall rep. Buildable rep => [SubExp] -> Body rep
resultBody ((VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
ress)
  Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec (Rep m))
pat (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Rep m), SubExp)]
-> LoopForm -> Body (Rep m) -> Exp (Rep m)
forall rep.
[(FParam rep, SubExp)] -> LoopForm -> Body rep -> Exp rep
Loop [(Param DeclType, SubExp)]
[(FParam (Rep m), SubExp)]
merge (VName -> IntType -> SubExp -> LoopForm
ForLoop VName
iter IntType
Int64 SubExp
len) Body (Rep m)
loopBody
transformSOAC Pat (LetDec (Rep m))
pat (Hist SubExp
len [VName]
imgs [HistOp (Rep m)]
ops Lambda (Rep m)
bucket_fun) = do
  VName
iter <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"iter"

  -- Bind arguments to parameters for the merge-variables.
  [Type]
hists_ts <- (VName -> m Type) -> [VName] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> m Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType ([VName] -> m [Type]) -> [VName] -> m [Type]
forall a b. (a -> b) -> a -> b
$ (HistOp (Rep m) -> [VName]) -> [HistOp (Rep m)] -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap HistOp (Rep m) -> [VName]
forall rep. HistOp rep -> [VName]
histDest [HistOp (Rep m)]
ops
  [Ident]
hists_out <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ([Char] -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
[Char] -> Type -> m Ident
newIdent [Char]
"dests") [Type]
hists_ts
  let merge :: [(Param DeclType, SubExp)]
merge = [Ident] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge [Ident]
hists_out ([SubExp] -> [(Param DeclType, SubExp)])
-> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (HistOp (Rep m) -> [SubExp]) -> [HistOp (Rep m)] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap ((VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> [SubExp])
-> (HistOp (Rep m) -> [VName]) -> HistOp (Rep m) -> [SubExp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp (Rep m) -> [VName]
forall rep. HistOp rep -> [VName]
histDest) [HistOp (Rep m)]
ops

  -- Bind lambda-bodies for operators.
  let iter_scope :: Scope (Rep m)
iter_scope = VName -> NameInfo (Rep m) -> Scope (Rep m) -> Scope (Rep m)
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
iter (IntType -> NameInfo (Rep m)
forall rep. IntType -> NameInfo rep
IndexName IntType
Int64) (Scope (Rep m) -> Scope (Rep m)) -> Scope (Rep m) -> Scope (Rep m)
forall a b. (a -> b) -> a -> b
$ [Param DeclType] -> Scope (Rep m)
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams ([Param DeclType] -> Scope (Rep m))
-> [Param DeclType] -> Scope (Rep m)
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge
  Body (Rep m)
loopBody <- Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m))
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
 SameScope somerep rep) =>
Builder rep (Body rep) -> m (Body rep)
runBodyBuilder (Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m)))
-> (Builder (Rep m) (Body (Rep m))
    -> Builder (Rep m) (Body (Rep m)))
-> Builder (Rep m) (Body (Rep m))
-> m (Body (Rep m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Scope (Rep m)
-> Builder (Rep m) (Body (Rep m)) -> Builder (Rep m) (Body (Rep m))
forall a.
Scope (Rep m)
-> BuilderT (Rep m) (State VNameSource) a
-> BuilderT (Rep m) (State VNameSource) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope (Rep m)
iter_scope (Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m)))
-> Builder (Rep m) (Body (Rep m)) -> m (Body (Rep m))
forall a b. (a -> b) -> a -> b
$ do
    -- Bind images to parameters of bucket function.
    [SubExp]
imgs' <- [VName]
-> (VName -> BuilderT (Rep m) (State VNameSource) SubExp)
-> BuilderT (Rep m) (State VNameSource) [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
imgs ((VName -> BuilderT (Rep m) (State VNameSource) SubExp)
 -> BuilderT (Rep m) (State VNameSource) [SubExp])
-> (VName -> BuilderT (Rep m) (State VNameSource) SubExp)
-> BuilderT (Rep m) (State VNameSource) [SubExp]
forall a b. (a -> b) -> a -> b
$ \VName
img -> do
      Type
img_t <- VName -> BuilderT (Rep m) (State VNameSource) Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
img
      [Char]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"pixel" (Exp (Rep (BuilderT (Rep m) (State VNameSource)))
 -> BuilderT (Rep m) (State VNameSource) SubExp)
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT (Rep m) (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
img (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
img_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
iter]
    [SubExp]
imgs'' <- (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (Result -> [SubExp])
-> BuilderT (Rep m) (State VNameSource) Result
-> BuilderT (Rep m) (State VNameSource) [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
-> [Exp (Rep (BuilderT (Rep m) (State VNameSource)))]
-> BuilderT (Rep m) (State VNameSource) Result
forall (m :: * -> *).
Transformer m =>
Lambda (Rep m) -> [Exp (Rep m)] -> m Result
bindLambda Lambda (Rep m)
Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
bucket_fun ((SubExp -> Exp (Rep m)) -> [SubExp] -> [Exp (Rep m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m))
-> (SubExp -> BasicOp) -> SubExp -> Exp (Rep m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) [SubExp]
imgs')

    -- Split out values from bucket function.
    let lens :: Int
lens = [Int] -> Int
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (HistOp (Rep m) -> Int) -> [HistOp (Rep m)] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (ShapeBase SubExp -> Int
forall a. ArrayShape a => a -> Int
shapeRank (ShapeBase SubExp -> Int)
-> (HistOp (Rep m) -> ShapeBase SubExp) -> HistOp (Rep m) -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp (Rep m) -> ShapeBase SubExp
forall rep. HistOp rep -> ShapeBase SubExp
histShape) [HistOp (Rep m)]
ops
        ops_inds :: [[SubExp]]
ops_inds = [Int] -> [SubExp] -> [[SubExp]]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp (Rep m) -> Int) -> [HistOp (Rep m)] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (ShapeBase SubExp -> Int
forall a. ArrayShape a => a -> Int
shapeRank (ShapeBase SubExp -> Int)
-> (HistOp (Rep m) -> ShapeBase SubExp) -> HistOp (Rep m) -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp (Rep m) -> ShapeBase SubExp
forall rep. HistOp rep -> ShapeBase SubExp
histShape) [HistOp (Rep m)]
ops) (Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
take Int
lens [SubExp]
imgs'')
        vals :: [[SubExp]]
vals = [Int] -> [SubExp] -> [[SubExp]]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp (Rep m) -> Int) -> [HistOp (Rep m)] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([Type] -> Int)
-> (HistOp (Rep m) -> [Type]) -> HistOp (Rep m) -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda (Rep m) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (Lambda (Rep m) -> [Type])
-> (HistOp (Rep m) -> Lambda (Rep m)) -> HistOp (Rep m) -> [Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp (Rep m) -> Lambda (Rep m)
forall rep. HistOp rep -> Lambda rep
histOp) [HistOp (Rep m)]
ops) ([SubExp] -> [[SubExp]]) -> [SubExp] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
drop Int
lens [SubExp]
imgs''
        hists_out' :: [[VName]]
hists_out' =
          [Int] -> [VName] -> [[VName]]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp (Rep m) -> Int) -> [HistOp (Rep m)] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([Type] -> Int)
-> (HistOp (Rep m) -> [Type]) -> HistOp (Rep m) -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda (Rep m) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (Lambda (Rep m) -> [Type])
-> (HistOp (Rep m) -> Lambda (Rep m)) -> HistOp (Rep m) -> [Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp (Rep m) -> Lambda (Rep m)
forall rep. HistOp rep -> Lambda rep
histOp) [HistOp (Rep m)]
ops) ([VName] -> [[VName]]) -> [VName] -> [[VName]]
forall a b. (a -> b) -> a -> b
$
            (Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
hists_out

    [[VName]]
hists_out'' <- [([VName], HistOp (Rep m), [SubExp], [SubExp])]
-> (([VName], HistOp (Rep m), [SubExp], [SubExp])
    -> BuilderT (Rep m) (State VNameSource) [VName])
-> BuilderT (Rep m) (State VNameSource) [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([[VName]]
-> [HistOp (Rep m)]
-> [[SubExp]]
-> [[SubExp]]
-> [([VName], HistOp (Rep m), [SubExp], [SubExp])]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [[VName]]
hists_out' [HistOp (Rep m)]
ops [[SubExp]]
ops_inds [[SubExp]]
vals) ((([VName], HistOp (Rep m), [SubExp], [SubExp])
  -> BuilderT (Rep m) (State VNameSource) [VName])
 -> BuilderT (Rep m) (State VNameSource) [[VName]])
-> (([VName], HistOp (Rep m), [SubExp], [SubExp])
    -> BuilderT (Rep m) (State VNameSource) [VName])
-> BuilderT (Rep m) (State VNameSource) [[VName]]
forall a b. (a -> b) -> a -> b
$ \([VName]
hist, HistOp (Rep m)
op, [SubExp]
idxs, [SubExp]
val) -> do
      -- Check whether the indexes are in-bound.  If they are not, we
      -- return the histograms unchanged.
      let outside_bounds_branch :: BuilderT
  (Rep m)
  (State VNameSource)
  (Body (Rep (BuilderT (Rep m) (State VNameSource))))
outside_bounds_branch = BuilderT (Rep m) (State VNameSource) Result
-> BuilderT
     (Rep m)
     (State VNameSource)
     (Body (Rep (BuilderT (Rep m) (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
m Result -> m (Body (Rep m))
buildBody_ (BuilderT (Rep m) (State VNameSource) Result
 -> BuilderT
      (Rep m)
      (State VNameSource)
      (Body (Rep (BuilderT (Rep m) (State VNameSource)))))
-> BuilderT (Rep m) (State VNameSource) Result
-> BuilderT
     (Rep m)
     (State VNameSource)
     (Body (Rep (BuilderT (Rep m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ Result -> BuilderT (Rep m) (State VNameSource) Result
forall a. a -> BuilderT (Rep m) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> BuilderT (Rep m) (State VNameSource) Result)
-> Result -> BuilderT (Rep m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ [VName] -> Result
varsRes [VName]
hist
          oob :: BuilderT
  (Rep m)
  (State VNameSource)
  (Exp (Rep (BuilderT (Rep m) (State VNameSource))))
oob = case [VName]
hist of
            [] -> SubExp
-> BuilderT
     (Rep m)
     (State VNameSource)
     (Exp (Rep (BuilderT (Rep m) (State VNameSource))))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp
 -> BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource)))))
-> SubExp
-> BuilderT
     (Rep m)
     (State VNameSource)
     (Exp (Rep (BuilderT (Rep m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ Bool -> SubExp
forall v. IsValue v => v -> SubExp
constant Bool
True
            VName
arr : [VName]
_ -> VName
-> [BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
-> BuilderT
     (Rep m)
     (State VNameSource)
     (Exp (Rep (BuilderT (Rep m) (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
VName -> [m (Exp (Rep m))] -> m (Exp (Rep m))
eOutOfBounds VName
arr ([BuilderT
    (Rep m)
    (State VNameSource)
    (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
 -> BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource)))))
-> [BuilderT
      (Rep m)
      (State VNameSource)
      (Exp (Rep (BuilderT (Rep m) (State VNameSource))))]
-> BuilderT
     (Rep m)
     (State VNameSource)
     (Exp (Rep (BuilderT (Rep m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (SubExp -> BuilderT (Rep m) (State VNameSource) (Exp (Rep m)))
-> [SubExp] -> [BuilderT (Rep m) (State VNameSource) (Exp (Rep m))]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> BuilderT (Rep m) (State VNameSource) (Exp (Rep m))
SubExp
-> BuilderT
     (Rep m)
     (State VNameSource)
     (Exp (Rep (BuilderT (Rep m) (State VNameSource))))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp [SubExp]
idxs

      [Char]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) [VName]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [VName]
letTupExp [Char]
"new_histo" (Exp (Rep m) -> BuilderT (Rep m) (State VNameSource) [VName])
-> (Builder (Rep m) (Body (Rep m))
    -> BuilderT (Rep m) (State VNameSource) (Exp (Rep m)))
-> Builder (Rep m) (Body (Rep m))
-> BuilderT (Rep m) (State VNameSource) [VName]
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< BuilderT
  (Rep m)
  (State VNameSource)
  (Exp (Rep (BuilderT (Rep m) (State VNameSource))))
-> BuilderT
     (Rep m)
     (State VNameSource)
     (Body (Rep (BuilderT (Rep m) (State VNameSource))))
-> BuilderT
     (Rep m)
     (State VNameSource)
     (Body (Rep (BuilderT (Rep m) (State VNameSource))))
-> BuilderT
     (Rep m)
     (State VNameSource)
     (Exp (Rep (BuilderT (Rep m) (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf BuilderT
  (Rep m)
  (State VNameSource)
  (Exp (Rep (BuilderT (Rep m) (State VNameSource))))
oob BuilderT
  (Rep m)
  (State VNameSource)
  (Body (Rep (BuilderT (Rep m) (State VNameSource))))
outside_bounds_branch (Builder (Rep m) (Body (Rep m))
 -> BuilderT (Rep m) (State VNameSource) [VName])
-> Builder (Rep m) (Body (Rep m))
-> BuilderT (Rep m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
        BuilderT (Rep m) (State VNameSource) Result
-> BuilderT
     (Rep m)
     (State VNameSource)
     (Body (Rep (BuilderT (Rep m) (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
m Result -> m (Body (Rep m))
buildBody_ (BuilderT (Rep m) (State VNameSource) Result
 -> BuilderT
      (Rep m)
      (State VNameSource)
      (Body (Rep (BuilderT (Rep m) (State VNameSource)))))
-> BuilderT (Rep m) (State VNameSource) Result
-> BuilderT
     (Rep m)
     (State VNameSource)
     (Body (Rep (BuilderT (Rep m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ do
          -- Read values from histogram.
          [SubExp]
h_val <- [VName]
-> (VName -> BuilderT (Rep m) (State VNameSource) SubExp)
-> BuilderT (Rep m) (State VNameSource) [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
hist ((VName -> BuilderT (Rep m) (State VNameSource) SubExp)
 -> BuilderT (Rep m) (State VNameSource) [SubExp])
-> (VName -> BuilderT (Rep m) (State VNameSource) SubExp)
-> BuilderT (Rep m) (State VNameSource) [SubExp]
forall a b. (a -> b) -> a -> b
$ \VName
arr -> do
            Type
arr_t <- VName -> BuilderT (Rep m) (State VNameSource) Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
            [Char]
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"read_hist" (Exp (Rep (BuilderT (Rep m) (State VNameSource)))
 -> BuilderT (Rep m) (State VNameSource) SubExp)
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (BuilderT (Rep m) (State VNameSource))))
-> BasicOp -> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_t ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ (SubExp -> DimIndex SubExp) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix [SubExp]
idxs

          -- Apply operator.
          Result
h_val' <- Lambda (Rep (BuilderT (Rep m) (State VNameSource)))
-> [Exp (Rep (BuilderT (Rep m) (State VNameSource)))]
-> BuilderT (Rep m) (State VNameSource) Result
forall (m :: * -> *).
Transformer m =>
Lambda (Rep m) -> [Exp (Rep m)] -> m Result
bindLambda (HistOp (Rep m) -> Lambda (Rep m)
forall rep. HistOp rep -> Lambda rep
histOp HistOp (Rep m)
op) ([Exp (Rep (BuilderT (Rep m) (State VNameSource)))]
 -> BuilderT (Rep m) (State VNameSource) Result)
-> [Exp (Rep (BuilderT (Rep m) (State VNameSource)))]
-> BuilderT (Rep m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ (SubExp -> Exp (Rep (BuilderT (Rep m) (State VNameSource))))
-> [SubExp] -> [Exp (Rep (BuilderT (Rep m) (State VNameSource)))]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m))
-> (SubExp -> BasicOp) -> SubExp -> Exp (Rep m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) ([SubExp] -> [Exp (Rep (BuilderT (Rep m) (State VNameSource)))])
-> [SubExp] -> [Exp (Rep (BuilderT (Rep m) (State VNameSource)))]
forall a b. (a -> b) -> a -> b
$ [SubExp]
h_val [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [SubExp]
val

          -- Write values back to histograms.
          [VName]
hist' <- [(VName, SubExpRes)]
-> ((VName, SubExpRes)
    -> BuilderT (Rep m) (State VNameSource) VName)
-> BuilderT (Rep m) (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName] -> Result -> [(VName, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
hist Result
h_val') (((VName, SubExpRes) -> BuilderT (Rep m) (State VNameSource) VName)
 -> BuilderT (Rep m) (State VNameSource) [VName])
-> ((VName, SubExpRes)
    -> BuilderT (Rep m) (State VNameSource) VName)
-> BuilderT (Rep m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
arr, SubExpRes Certs
cs SubExp
v) -> do
            Type
arr_t <- VName -> BuilderT (Rep m) (State VNameSource) Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
            Certs
-> BuilderT (Rep m) (State VNameSource) VName
-> BuilderT (Rep m) (State VNameSource) VName
forall a.
Certs
-> BuilderT (Rep m) (State VNameSource) a
-> BuilderT (Rep m) (State VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (BuilderT (Rep m) (State VNameSource) VName
 -> BuilderT (Rep m) (State VNameSource) VName)
-> (Exp (Rep m) -> BuilderT (Rep m) (State VNameSource) VName)
-> Exp (Rep m)
-> BuilderT (Rep m) (State VNameSource) VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Char]
-> VName
-> Slice SubExp
-> Exp (Rep (BuilderT (Rep m) (State VNameSource)))
-> BuilderT (Rep m) (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> Slice SubExp -> Exp (Rep m) -> m VName
letInPlace [Char]
"hist_out" VName
arr (Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_t ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ (SubExp -> DimIndex SubExp) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix [SubExp]
idxs) (Exp (Rep m) -> BuilderT (Rep m) (State VNameSource) VName)
-> Exp (Rep m) -> BuilderT (Rep m) (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
              BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$
                SubExp -> BasicOp
SubExp SubExp
v

          Result -> BuilderT (Rep m) (State VNameSource) Result
forall a. a -> BuilderT (Rep m) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> BuilderT (Rep m) (State VNameSource) Result)
-> Result -> BuilderT (Rep m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ [VName] -> Result
varsRes [VName]
hist'

    Body (Rep m) -> Builder (Rep m) (Body (Rep m))
forall a. a -> BuilderT (Rep m) (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Body (Rep m) -> Builder (Rep m) (Body (Rep m)))
-> Body (Rep m) -> Builder (Rep m) (Body (Rep m))
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Body (Rep m)
forall rep. Buildable rep => [SubExp] -> Body rep
resultBody ([SubExp] -> Body (Rep m)) -> [SubExp] -> Body (Rep m)
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ [[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
hists_out''

  -- Wrap up the above into a for-loop.
  Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec (Rep m))
pat (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Rep m), SubExp)]
-> LoopForm -> Body (Rep m) -> Exp (Rep m)
forall rep.
[(FParam rep, SubExp)] -> LoopForm -> Body rep -> Exp rep
Loop [(Param DeclType, SubExp)]
[(FParam (Rep m), SubExp)]
merge (VName -> IntType -> SubExp -> LoopForm
ForLoop VName
iter IntType
Int64 SubExp
len) Body (Rep m)
loopBody

-- | Recursively first-order-transform a lambda.  The parameters and
-- return types are carried over unchanged; only the body is
-- transformed, with the lambda's parameters brought into scope while
-- doing so.
transformLambda ::
  ( MonadFreshNames m,
    Buildable rep,
    BuilderOps rep,
    LocalScope somerep m,
    SameScope somerep rep,
    LetDec rep ~ LetDec SOACS,
    Alias.AliasableRep rep
  ) =>
  Lambda SOACS ->
  m (AST.Lambda rep)
transformLambda (Lambda params rettype body) =
  Lambda params rettype
    <$> runBodyBuilder (localScope (scopeOfLParams params) (transformBody body))

-- | Write each value of @vs@ into the corresponding variable of @ks@ at
-- index @i@ (via 'fullSlice' on the variable's type), returning the
-- names of the updated variables.  Variables of accumulator ('Acc')
-- type are not indexed; the value is instead bound to a new variable
-- based on the name @lw_acc@.  As with 'zipWithM', surplus elements of
-- the longer list are ignored.
letwith :: (Transformer m) => [VName] -> SubExp -> [SubExp] -> m [VName]
letwith ks i vs = zipWithM update ks vs
  where
    update k v = do
      k_t <- lookupType k
      let v_e = BasicOp $ SubExp v
      case k_t of
        Acc {} -> letExp "lw_acc" v_e
        _ -> letInPlace "lw_dest" k (fullSlice k_t [DimFix i]) v_e

-- | Bind the argument expressions to the lambda's parameters, pairing
-- them positionally.  Then bind the lambda body in the current builder
-- context and return its result.  An argument for a parameter of
-- primitive type is bound directly; any other argument is bound
-- through 'eCopy'.
bindLambda ::
  (Transformer m) =>
  AST.Lambda (Rep m) ->
  [AST.Exp (Rep m)] ->
  m Result
bindLambda (Lambda params _ body) args = do
  zipWithM_ bindParam params args
  bodyBind body
  where
    bindParam param arg
      | primType (paramType param) = letBindNames [paramName param] arg
      -- NOTE(review): presumably copied so the lambda body may consume
      -- or update the parameter without touching the original argument.
      | otherwise = letBindNames [paramName param] =<< eCopy (pure arg)

-- | Build the merge parameters of a loop from the given identifiers and
-- their initial values, giving every parameter 'Unique' uniqueness.
-- The pairing itself is done by 'loopMerge''.
loopMerge :: [Ident] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge vars = loopMerge' $ map (,Unique) vars

-- | Build loop merge parameters from identifiers annotated with the
-- desired uniqueness, pairing each parameter with its initial value.
--
-- Each resulting 'Param' has empty attributes and the identifier's
-- name. Its declared type is the identifier's type converted with
-- 'toDecl' at the given uniqueness.  The identifier list and the
-- value list are combined with 'zip', so surplus elements of the
-- longer list are dropped.
loopMerge' :: [(Ident, Uniqueness)] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge' vars vals =
  [ (Param mempty pname $ toDecl ptype u, val)
  | ((Ident pname ptype, u), val) <- zip vars vals
  ]

-- Note [Translation of Screma]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--
-- Screma is the most general SOAC.  It is translated by constructing
-- a loop that contains several groups of parameters, in this order:
--
-- (0) Scan accumulator, initialised with neutral element.
-- (1) Scan results, initialised with Scratch.
-- (2) Reduce results (also functioning as accumulators),
--     initialised with neutral element.
-- (3) Map results, mostly initialised with Scratch.
--
-- However, category (3) is a little more tricky in the case where one
-- of the results is an Acc.  In that case, the result is not an
-- array, but another Acc.  Any Acc result of a Map must correspond to
-- an Acc that is an input to the map, and the result is initialised
-- to be that input.  This requires a 1:1 relationship between Acc
-- inputs and Acc outputs, which the type checker should enforce.
-- There is no guarantee that the map results appear in any particular
-- order (e.g. accumulator results before non-accumulator results), so
-- we need to do a little sleuthing to establish the relationship.
--
-- Inside the loop, the non-Acc parameters to map_lam become for-in
-- parameters.  Acc parameters refer to the loop parameters for the
-- corresponding Map result instead.
--
-- Intuitively, a Screma(w,
--                       (scan_op, scan_ne),
--                       (red_op, red_ne),
--                       map_fn,
--                       {acc_input, arr_input})
--
-- then becomes
--
-- loop (scan_acc, scan_arr, red_acc, map_acc, map_arr) =
--   for i < w, x in arr_input do
--     let (a,b,map_acc',d) = map_fn(map_acc, x)
--     let scan_acc' = scan_op(scan_acc, a)
--     let scan_arr[i] = scan_acc'
--     let red_acc' = red_op(red_acc, b)
--     let map_arr[i] = d
--     in (scan_acc', scan_arr, red_acc', map_acc', map_arr)