{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | The code generator cannot handle the array combinators (@map@ and
-- friends), so this module was written to transform them into the
-- equivalent do-loops.  The transformation is currently rather naive,
-- and it is certainly worth considering whether such transformations
-- could instead be expressed in-place.
module Futhark.Transform.FirstOrderTransform
  ( transformFunDef,
    transformConsts,
    FirstOrderLore,
    Transformer,
    transformStmRecursively,
    transformLambda,
    transformSOAC,
  )
where

import Control.Monad.Except
import Control.Monad.State
import Data.List (zip4)
import qualified Data.Map.Strict as M
import qualified Futhark.IR as AST
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Util (chunks, splitAt3)

-- | Constraint bundle that a lore must satisfy in order to be the
-- target of first-order transformation.  Besides being 'Bindable'
-- with 'BinderOps', the target lore must decorate let-bound patterns
-- ('LetDec') and lambda parameters ('LParamInfo') exactly as SOACS
-- does, so that SOACS patterns and lambda parameters can be reused
-- in the target lore without conversion.
type FirstOrderLore lore =
  ( LParamInfo SOACS ~ LParamInfo lore,
    LetDec SOACS ~ LetDec lore,
    BinderOps lore,
    Bindable lore
  )

-- | First-order-transform a single function, with the given scope
-- provided by top-level constants.
--
-- The function's parameters are added to @consts_scope@ while the body
-- is transformed.  The transformed statements are produced inside
-- 'insertStmsM', so they end up in the new body.  The leftover
-- statements returned by 'runBinderT' are discarded; they are presumably
-- always empty for that reason (NOTE(review): confirm against
-- 'insertStmsM').  Entry point, attributes, name, return type and
-- parameters are carried over unchanged.
transformFunDef ::
  (MonadFreshNames m, FirstOrderLore tolore) =>
  Scope tolore ->
  FunDef SOACS ->
  m (AST.FunDef tolore)
transformFunDef consts_scope (FunDef entry attrs fname rettype params body) = do
  (body', _) <- modifyNameSource $ runState $ runBinderT build consts_scope
  return $ FunDef entry attrs fname rettype params body'
  where
    -- Transform the body with the function parameters in scope.
    build =
      localScope (scopeOfFParams params) $ insertStmsM $ transformBody body

-- | First-order-transform these top-level constants.
--
-- Each statement is rewritten with 'transformStmRecursively', starting
-- from an empty scope.  The statements emitted by the binder are
-- returned; the unit result of the traversal is dropped.
transformConsts ::
  (MonadFreshNames m, FirstOrderLore tolore) =>
  Stms SOACS ->
  m (AST.Stms tolore)
transformConsts stms =
  fmap snd $ modifyNameSource $ runState $ runBinderT build mempty
  where
    build = mapM_ transformStmRecursively stms

-- | Constraint bundle for a monad that can drive first-order
-- transformation.  The monad must be able to bind statements
-- ('MonadBinder') and locally extend its scope ('LocalScope').  Its
-- lore must be 'Bindable', provide 'BinderOps', and decorate lambda
-- parameters the same way SOACS does.
type Transformer m =
  ( LParamInfo SOACS ~ LParamInfo (Lore m),
    BinderOps (Lore m),
    Bindable (Lore m),
    LocalScope (Lore m) m,
    MonadBinder m
  )

-- | First-order-transform every statement of a SOACS body, keeping the
-- body's original result.  The statements are transformed with
-- 'transformStmRecursively' inside 'insertStmsM', which wraps the
-- result body built from @res@.
transformBody ::
  (Transformer m, LetDec (Lore m) ~ LetDec SOACS) =>
  Body ->
  m (AST.Body (Lore m))
transformBody (Body () stms res) = insertStmsM $ do
  mapM_ transformStmRecursively stms
  return $ resultBody res

-- | First transform any nested t'Body' or t'Lambda' elements, then
-- apply 'transformSOAC' if the expression is a SOAC.
--
-- * A SOAC statement has its lambdas rewritten with 'transformLambda',
--   then it is handed to 'transformSOAC' together with the original
--   pattern.
-- * Any other expression has its nested bodies transformed, each in
--   the scope the expression provides for it.  It is then re-bound to
--   the original pattern with 'letBind'.
--
-- In both cases the work runs under @auxing aux@, with the statement's
-- original 'StmAux'.
transformStmRecursively ::
  (Transformer m, LetDec (Lore m) ~ LetDec SOACS) =>
  Stm ->
  m ()
transformStmRecursively (Let pat aux (Op soac)) =
  auxing aux $ transformSOAC pat =<< mapSOACM soacTransform soac
  where
    soacTransform = identitySOACMapper {mapOnSOACLambda = transformLambda}
transformStmRecursively (Let pat aux e) =
  auxing aux $ letBind pat =<< mapExpM transform e
  where
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody,
          mapOnRetType = return,
          mapOnBranchType = return,
          mapOnFParam = return,
          mapOnLParam = return,
          -- Not expected to be reached: 'Op' expressions are matched by
          -- the equation above.
          mapOnOp = error "Unhandled Op in first order transform"
        }

-- | Transform a single 'SOAC' into a do-loop.  The body of the lambda
-- is untouched, and may or may not contain further 'SOAC's depending
-- on the given lore.
transformSOAC ::
  Transformer m =>
  AST.Pattern (Lore m) ->
  SOAC (Lore m) ->
  m ()
transformSOAC :: Pattern (Lore m) -> SOAC (Lore m) -> m ()
transformSOAC Pattern (Lore m)
pat (Screma SubExp
w form :: ScremaForm (Lore m)
form@(ScremaForm [Scan (Lore m)]
scans [Reduce (Lore m)]
reds Lambda (Lore m)
map_lam) [VName]
arrs) = do
  -- Start by combining all the reduction parts into a single operator
  let Reduce Commutativity
_ Lambda (Lore m)
red_lam Result
red_nes = [Reduce (Lore m)] -> Reduce (Lore m)
forall lore. Bindable lore => [Reduce lore] -> Reduce lore
singleReduce [Reduce (Lore m)]
reds
      Scan Lambda (Lore m)
scan_lam Result
scan_nes = [Scan (Lore m)] -> Scan (Lore m)
forall lore. Bindable lore => [Scan lore] -> Scan lore
singleScan [Scan (Lore m)]
scans
      ([Type]
scan_arr_ts, [Type]
_red_ts, [Type]
map_arr_ts) =
        Int -> Int -> [Type] -> ([Type], [Type], [Type])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
scan_nes) (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
red_nes) ([Type] -> ([Type], [Type], [Type]))
-> [Type] -> ([Type], [Type], [Type])
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm (Lore m) -> [Type]
forall lore. SubExp -> ScremaForm lore -> [Type]
scremaType SubExp
w ScremaForm (Lore m)
form
  [VName]
scan_arrs <- [Type] -> m [VName]
forall (m :: * -> *). Transformer m => [Type] -> m [VName]
resultArray [Type]
scan_arr_ts
  [VName]
map_arrs <- [Type] -> m [VName]
forall (m :: * -> *). Transformer m => [Type] -> m [VName]
resultArray [Type]
map_arr_ts

  -- We construct a loop that contains several groups of merge
  -- parameters:
  --
  -- (0) scan accumulator.
  -- (1) scan results.
  -- (2) reduce results (and accumulator).
  -- (3) map results.
  --
  -- Inside the loop, the parameters to map_lam become for-in
  -- parameters.

  [Param DeclType]
scanacc_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"scanacc" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Nonunique) ([Type] -> m [Param DeclType]) -> [Type] -> m [Param DeclType]
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Lore m)
scan_lam
  [Param DeclType]
scanout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"scanout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Unique) [Type]
scan_arr_ts
  [Param DeclType]
redout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"redout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Nonunique) ([Type] -> m [Param DeclType]) -> [Type] -> m [Param DeclType]
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Lore m)
red_lam
  [Param DeclType]
mapout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"mapout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Unique) [Type]
map_arr_ts

  let merge :: [(Param DeclType, SubExp)]
merge =
        [[(Param DeclType, SubExp)]] -> [(Param DeclType, SubExp)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat
          [ [Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
scanacc_params Result
scan_nes,
            [Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
scanout_params (Result -> [(Param DeclType, SubExp)])
-> Result -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
scan_arrs,
            [Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
redout_params Result
red_nes,
            [Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
mapout_params (Result -> [(Param DeclType, SubExp)])
-> Result -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
map_arrs
          ]
  VName
i <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"i"
  let loopform :: LoopForm (Lore m)
loopform = VName
-> IntType
-> SubExp
-> [(LParam (Lore m), VName)]
-> LoopForm (Lore m)
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
i IntType
Int64 SubExp
w []

  Body (Lore m)
loop_body <- Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall a b. (a -> b) -> a -> b
$
    Scope (Lore m)
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param DeclType] -> Scope (Lore m)
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams ([Param DeclType] -> Scope (Lore m))
-> [Param DeclType] -> Scope (Lore m)
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge) (Binder (Lore m) (Body (Lore m))
 -> Binder (Lore m) (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$
      LoopForm (Lore m)
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall lore a (m :: * -> *) b.
(Scoped lore a, LocalScope lore m) =>
a -> m b -> m b
inScopeOf LoopForm (Lore m)
loopform (Binder (Lore m) (Body (Lore m))
 -> Binder (Lore m) (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ do
        [(Param Type, VName)]
-> ((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
-> BinderT (Lore m) (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda (Lore m) -> [LParam (Lore m)]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda (Lore m)
map_lam) [VName]
arrs) (((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
 -> BinderT (Lore m) (State VNameSource) ())
-> ((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
-> BinderT (Lore m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, VName
arr) -> do
          Type
arr_t <- VName -> BinderT (Lore m) (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
          [VName]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p] (Exp (Lore (BinderT (Lore m) (State VNameSource)))
 -> BinderT (Lore m) (State VNameSource) ())
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
            BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$
              VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
                Type -> Slice SubExp -> Slice SubExp
fullSlice Type
arr_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i]

        -- Insert the statements of the lambda.  We have taken care to
        -- ensure that the parameters are bound at this point.
        (Stm (Lore m) -> BinderT (Lore m) (State VNameSource) ())
-> Seq (Stm (Lore m)) -> BinderT (Lore m) (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm (Lore m) -> BinderT (Lore m) (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm (Seq (Stm (Lore m)) -> BinderT (Lore m) (State VNameSource) ())
-> Seq (Stm (Lore m)) -> BinderT (Lore m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Body (Lore m) -> Seq (Stm (Lore m))
forall lore. BodyT lore -> Stms lore
bodyStms (Body (Lore m) -> Seq (Stm (Lore m)))
-> Body (Lore m) -> Seq (Stm (Lore m))
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> Body (Lore m)
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda (Lore m)
map_lam
        -- Split into scan results, reduce results, and map results.
        let (Result
scan_res, Result
red_res, Result
map_res) =
              Int -> Int -> Result -> (Result, Result, Result)
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
scan_nes) (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
red_nes) (Result -> (Result, Result, Result))
-> Result -> (Result, Result, Result)
forall a b. (a -> b) -> a -> b
$
                Body (Lore m) -> Result
forall lore. BodyT lore -> Result
bodyResult (Body (Lore m) -> Result) -> Body (Lore m) -> Result
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> Body (Lore m)
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda (Lore m)
map_lam

        Result
scan_res' <-
          Lambda (Lore (BinderT (Lore m) (State VNameSource)))
-> [BinderT
      (Lore m)
      (State VNameSource)
      (Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *).
MonadBinder m =>
Lambda (Lore m) -> [m (Exp (Lore m))] -> m Result
eLambda Lambda (Lore m)
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
scan_lam ([BinderT
    (Lore m)
    (State VNameSource)
    (Exp (Lore (BinderT (Lore m) (State VNameSource))))]
 -> BinderT (Lore m) (State VNameSource) Result)
-> [BinderT
      (Lore m)
      (State VNameSource)
      (Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
            (SubExp -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))]
forall a b. (a -> b) -> [a] -> [b]
map (ExpT (Lore m)
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT (Lore m)
 -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> (SubExp -> ExpT (Lore m))
-> SubExp
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) (Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))])
-> Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))]
forall a b. (a -> b) -> a -> b
$
              (Param DeclType -> SubExp) -> [Param DeclType] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (Param DeclType -> VName) -> Param DeclType -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> VName
forall dec. Param dec -> VName
paramName) [Param DeclType]
scanacc_params Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
scan_res
        Result
red_res' <-
          Lambda (Lore (BinderT (Lore m) (State VNameSource)))
-> [BinderT
      (Lore m)
      (State VNameSource)
      (Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *).
MonadBinder m =>
Lambda (Lore m) -> [m (Exp (Lore m))] -> m Result
eLambda Lambda (Lore m)
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
red_lam ([BinderT
    (Lore m)
    (State VNameSource)
    (Exp (Lore (BinderT (Lore m) (State VNameSource))))]
 -> BinderT (Lore m) (State VNameSource) Result)
-> [BinderT
      (Lore m)
      (State VNameSource)
      (Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
            (SubExp -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))]
forall a b. (a -> b) -> [a] -> [b]
map (ExpT (Lore m)
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT (Lore m)
 -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> (SubExp -> ExpT (Lore m))
-> SubExp
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) (Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))])
-> Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))]
forall a b. (a -> b) -> a -> b
$
              (Param DeclType -> SubExp) -> [Param DeclType] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (Param DeclType -> VName) -> Param DeclType -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> VName
forall dec. Param dec -> VName
paramName) [Param DeclType]
redout_params Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
red_res

        -- Write the scan accumulator to the scan result arrays.
        [VName]
scan_outarrs <-
          [VName]
-> BinderT
     (Lore m)
     (State VNameSource)
     (Exp (Lore (BinderT (Lore m) (State VNameSource))))
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) [VName]
forall (m :: * -> *).
Transformer m =>
[VName] -> m (Exp (Lore m)) -> [Exp (Lore m)] -> m [VName]
letwith ((Param DeclType -> VName) -> [Param DeclType] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param DeclType -> VName
forall dec. Param dec -> VName
paramName [Param DeclType]
scanout_params) (SubExp -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall (f :: * -> *) lore. Applicative f => SubExp -> f (Exp lore)
pexp (VName -> SubExp
Var VName
i)) ([Exp (Lore (BinderT (Lore m) (State VNameSource)))]
 -> BinderT (Lore m) (State VNameSource) [VName])
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
            (SubExp -> ExpT (Lore m)) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) Result
scan_res'

        -- Write the map results to the map result arrays.
        [VName]
map_outarrs <-
          [VName]
-> BinderT
     (Lore m)
     (State VNameSource)
     (Exp (Lore (BinderT (Lore m) (State VNameSource))))
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) [VName]
forall (m :: * -> *).
Transformer m =>
[VName] -> m (Exp (Lore m)) -> [Exp (Lore m)] -> m [VName]
letwith ((Param DeclType -> VName) -> [Param DeclType] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param DeclType -> VName
forall dec. Param dec -> VName
paramName [Param DeclType]
mapout_params) (SubExp -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall (f :: * -> *) lore. Applicative f => SubExp -> f (Exp lore)
pexp (VName -> SubExp
Var VName
i)) ([Exp (Lore (BinderT (Lore m) (State VNameSource)))]
 -> BinderT (Lore m) (State VNameSource) [VName])
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
            (SubExp -> ExpT (Lore m)) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) Result
map_res

        Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall (m :: * -> *) a. Monad m => a -> m a
return (Body (Lore m) -> Binder (Lore m) (Body (Lore m)))
-> Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$
          Result -> Body (Lore m)
forall lore. Bindable lore => Result -> Body lore
resultBody (Result -> Body (Lore m)) -> Result -> Body (Lore m)
forall a b. (a -> b) -> a -> b
$
            [Result] -> Result
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat
              [ Result
scan_res',
                (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
scan_outarrs,
                Result
red_res',
                (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
map_outarrs
              ]

  -- We need to discard the final scan accumulators, as they are not
  -- bound in the original pattern.
  [VName]
names <-
    ([VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ Pattern (Lore m) -> [VName]
forall dec. PatternT dec -> [VName]
patternNames Pattern (Lore m)
pat)
      ([VName] -> [VName]) -> m [VName] -> m [VName]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int -> m VName -> m [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM ([Param DeclType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Param DeclType]
scanacc_params) ([Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"discard")
  [VName] -> ExpT (Lore m) -> m ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName]
names (ExpT (Lore m) -> m ()) -> ExpT (Lore m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Lore m), SubExp)]
-> [(FParam (Lore m), SubExp)]
-> LoopForm (Lore m)
-> Body (Lore m)
-> ExpT (Lore m)
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [] [(Param DeclType, SubExp)]
[(FParam (Lore m), SubExp)]
merge LoopForm (Lore m)
loopform Body (Lore m)
loop_body
transformSOAC pat (Stream w stream_form lam arrs) = do
  -- Sequentialise the stream as a do-loop over the outer dimension
  -- `w`, feeding the lambda body a chunk of exactly one element per
  -- iteration.  Ideally this outer loop ends up being the only one:
  -- every inner loop produced from the body then has a single
  -- iteration and can be simplified away.
  let accs = getStreamAccums stream_form
      num_accs = length accs
      (chunk_param, acc_params, elem_params) =
        partitionChunkedFoldParameters num_accs (lambdaParams lam)
      -- Whatever the lambda returns past the accumulators is a map-out
      -- result that has to be collected into a full-size array.
      mapout_ts = drop num_accs (lambdaReturnType lam)

  -- For each map-out result: a fresh unique merge parameter whose
  -- initial value is a scratch array with outer size `w`.  The
  -- parameter name is drawn before the scratch array is bound.
  mapout_merge <- forM mapout_ts $ \t -> do
    let full_t = t `setOuterSize` w
    mapout_param <- newParam "stream_mapout" (toDecl full_t Unique)
    scratch_arr <-
      letSubExp "stream_mapout_scratch" . BasicOp $
        Scratch (elemType full_t) (arrayDims full_t)
    return (mapout_param, scratch_arr)

  -- Accumulators come first in the merge list, then the map-out arrays.
  let acc_merge = zip (map (fmap (`toDecl` Nonunique)) acc_params) accs
      merge = acc_merge ++ mapout_merge
      mapout_params = map fst mapout_merge

  i <- newVName "i"

  let loop_form = ForLoop i Int64 w []

  -- The chunk size is fixed at one and bound once, outside the loop.
  letBindNames [paramName chunk_param] . BasicOp . SubExp $ intConst Int64 1

  loop_body <-
    runBodyBinder . localScope (scopeOf loop_form <> scopeOfFParams (map fst merge)) $ do
      -- A one-element slice starting at the current iteration index.
      let one_elem =
            [DimSlice (Var i) (Var (paramName chunk_param)) (intConst Int64 1)]

      forM_ (zip elem_params arrs) $ \(p, arr) ->
        letBindNames [paramName p] . BasicOp . Index arr $
          fullSlice (paramType p) one_elem

      (acc_res, mapout_res) <- splitAt num_accs <$> bodyBind (lambdaBody lam)

      -- Write each map-out chunk into its slot of the merge array.
      mapout_res' <- forM (zip mapout_params mapout_res) $ \(p, se) ->
        letSubExp "mapout_res" . BasicOp $
          Update (paramName p) (fullSlice (paramType p) one_elem) se

      resultBodyM (acc_res ++ mapout_res')

  letBind pat $ DoLoop [] merge loop_form loop_body
transformSOAC pat (Scatter len lam ivs as) = do
  iter <- newVName "write_iter"

  -- Each entry of `as` carries the number of writes targeting that
  -- destination array, plus the array itself.
  let (_, writes_per_dest, dests) = unzip3 as
  dest_ts <- mapM lookupType dests
  dest_outs <- mapM (newIdent "write_out") dest_ts

  -- The lambda returns all indices first, followed by all values.
  let num_indices = length (lambdaReturnType lam) `div` 2
      -- Scatter is in-place, so the destination arrays themselves are
      -- the initial values of the loop's merge variables.
      merge = loopMerge dest_outs (map Var dests)
      loop_scope =
        M.insert iter (IndexName Int64) (scopeOfFParams (map fst merge))

  loop_body <- runBodyBinder . localScope loop_scope $ do
    -- Read the current element of every input array.
    elems <- forM ivs $ \iv -> do
      iv_t <- lookupType iv
      letSubExp "write_iv" . BasicOp . Index iv $
        fullSlice iv_t [DimFix (Var iter)]
    lam_res <- bindLambda lam (map (BasicOp . SubExp) elems)

    let (idx_res, val_res) = splitAt num_indices lam_res
        per_dest_idxs = chunks writes_per_dest idx_res
        per_dest_vals = chunks writes_per_dest val_res
        -- Perform a single (index, value) write into an array.
        writeOne arr' (idx, v) =
          letExp "write_out" =<< eWriteArray arr' [eSubExp idx] (eSubExp v)

    -- Thread each destination array through all of its writes in order.
    updated <-
      forM (zip3 per_dest_idxs per_dest_vals (map identName dest_outs)) $
        \(dest_idxs, dest_vals, arr) ->
          foldM writeOne arr (zip dest_idxs dest_vals)

    return $ resultBody (map Var updated)

  letBind pat $ DoLoop [] merge (ForLoop iter Int64 len []) loop_body
transformSOAC Pattern (Lore m)
pat (Hist SubExp
len [HistOp (Lore m)]
ops Lambda (Lore m)
bucket_fun [VName]
imgs) = do
  VName
iter <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"iter"

  -- Bind arguments to parameters for the merge-variables.
  [Type]
hists_ts <- (VName -> m Type) -> [VName] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType ([VName] -> m [Type]) -> [VName] -> m [Type]
forall a b. (a -> b) -> a -> b
$ (HistOp (Lore m) -> [VName]) -> [HistOp (Lore m)] -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap HistOp (Lore m) -> [VName]
forall lore. HistOp lore -> [VName]
histDest [HistOp (Lore m)]
ops
  [Ident]
hists_out <- (Type -> m Ident) -> [Type] -> m [Ident]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
[Char] -> Type -> m Ident
newIdent [Char]
"dests") [Type]
hists_ts
  let merge :: [(Param DeclType, SubExp)]
merge = [Ident] -> Result -> [(Param DeclType, SubExp)]
loopMerge [Ident]
hists_out (Result -> [(Param DeclType, SubExp)])
-> Result -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (HistOp (Lore m) -> Result) -> [HistOp (Lore m)] -> Result
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> Result)
-> (HistOp (Lore m) -> [VName]) -> HistOp (Lore m) -> Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp (Lore m) -> [VName]
forall lore. HistOp lore -> [VName]
histDest) [HistOp (Lore m)]
ops

  -- Bind lambda-bodies for operators.
  Body (Lore m)
loopBody <- Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall a b. (a -> b) -> a -> b
$
    Scope (Lore m)
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope
      ( VName -> NameInfo (Lore m) -> Scope (Lore m) -> Scope (Lore m)
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
iter (IntType -> NameInfo (Lore m)
forall lore. IntType -> NameInfo lore
IndexName IntType
Int64) (Scope (Lore m) -> Scope (Lore m))
-> Scope (Lore m) -> Scope (Lore m)
forall a b. (a -> b) -> a -> b
$
          [Param DeclType] -> Scope (Lore m)
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams ([Param DeclType] -> Scope (Lore m))
-> [Param DeclType] -> Scope (Lore m)
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge
      )
      (Binder (Lore m) (Body (Lore m))
 -> Binder (Lore m) (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ do
        -- Bind images to parameters of bucket function.
        Result
imgs' <- [VName]
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
imgs ((VName -> BinderT (Lore m) (State VNameSource) SubExp)
 -> BinderT (Lore m) (State VNameSource) Result)
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ \VName
img -> do
          Type
img_t <- VName -> BinderT (Lore m) (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
img
          [Char]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"pixel" (Exp (Lore (BinderT (Lore m) (State VNameSource)))
 -> BinderT (Lore m) (State VNameSource) SubExp)
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
img (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> Slice SubExp -> Slice SubExp
fullSlice Type
img_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
iter]
        Result
imgs'' <- Lambda (Lore (BinderT (Lore m) (State VNameSource)))
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *).
Transformer m =>
Lambda (Lore m) -> [Exp (Lore m)] -> m Result
bindLambda Lambda (Lore m)
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
bucket_fun ([Exp (Lore (BinderT (Lore m) (State VNameSource)))]
 -> BinderT (Lore m) (State VNameSource) Result)
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ (SubExp -> ExpT (Lore m)) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) Result
imgs'

        -- Split out values from bucket function.
        let lens :: Int
lens = [HistOp (Lore m)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp (Lore m)]
ops
            inds :: Result
inds = Int -> Result -> Result
forall a. Int -> [a] -> [a]
take Int
lens Result
imgs''
            vals :: [Result]
vals = [Int] -> Result -> [Result]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp (Lore m) -> Int) -> [HistOp (Lore m)] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([Type] -> Int)
-> (HistOp (Lore m) -> [Type]) -> HistOp (Lore m) -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda (Lore m) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (Lambda (Lore m) -> [Type])
-> (HistOp (Lore m) -> Lambda (Lore m))
-> HistOp (Lore m)
-> [Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp (Lore m) -> Lambda (Lore m)
forall lore. HistOp lore -> Lambda lore
histOp) [HistOp (Lore m)]
ops) (Result -> [Result]) -> Result -> [Result]
forall a b. (a -> b) -> a -> b
$ Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop Int
lens Result
imgs''
            hists_out' :: [[VName]]
hists_out' =
              [Int] -> [VName] -> [[VName]]
forall a. [Int] -> [a] -> [[a]]
chunks ((HistOp (Lore m) -> Int) -> [HistOp (Lore m)] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([Type] -> Int)
-> (HistOp (Lore m) -> [Type]) -> HistOp (Lore m) -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda (Lore m) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (Lambda (Lore m) -> [Type])
-> (HistOp (Lore m) -> Lambda (Lore m))
-> HistOp (Lore m)
-> [Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp (Lore m) -> Lambda (Lore m)
forall lore. HistOp lore -> Lambda lore
histOp) [HistOp (Lore m)]
ops) ([VName] -> [[VName]]) -> [VName] -> [[VName]]
forall a b. (a -> b) -> a -> b
$
                (Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
hists_out

        [[VName]]
hists_out'' <- [([VName], HistOp (Lore m), SubExp, Result)]
-> (([VName], HistOp (Lore m), SubExp, Result)
    -> BinderT (Lore m) (State VNameSource) [VName])
-> BinderT (Lore m) (State VNameSource) [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([[VName]]
-> [HistOp (Lore m)]
-> Result
-> [Result]
-> [([VName], HistOp (Lore m), SubExp, Result)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [[VName]]
hists_out' [HistOp (Lore m)]
ops Result
inds [Result]
vals) ((([VName], HistOp (Lore m), SubExp, Result)
  -> BinderT (Lore m) (State VNameSource) [VName])
 -> BinderT (Lore m) (State VNameSource) [[VName]])
-> (([VName], HistOp (Lore m), SubExp, Result)
    -> BinderT (Lore m) (State VNameSource) [VName])
-> BinderT (Lore m) (State VNameSource) [[VName]]
forall a b. (a -> b) -> a -> b
$ \([VName]
hist, HistOp (Lore m)
op, SubExp
idx, Result
val) -> do
          -- Check whether the indexes are in-bound.  If they are not, we
          -- return the histograms unchanged.
          let outside_bounds_branch :: BinderT
  (Lore m)
  (State VNameSource)
  (Body (Lore (BinderT (Lore m) (State VNameSource))))
outside_bounds_branch = BinderT
  (Lore m)
  (State VNameSource)
  (Body (Lore (BinderT (Lore m) (State VNameSource))))
-> BinderT
     (Lore m)
     (State VNameSource)
     (Body (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM (BinderT
   (Lore m)
   (State VNameSource)
   (Body (Lore (BinderT (Lore m) (State VNameSource))))
 -> BinderT
      (Lore m)
      (State VNameSource)
      (Body (Lore (BinderT (Lore m) (State VNameSource)))))
-> BinderT
     (Lore m)
     (State VNameSource)
     (Body (Lore (BinderT (Lore m) (State VNameSource))))
-> BinderT
     (Lore m)
     (State VNameSource)
     (Body (Lore (BinderT (Lore m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ Result
-> BinderT
     (Lore m)
     (State VNameSource)
     (Body (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *). MonadBinder m => Result -> m (Body (Lore m))
resultBodyM (Result
 -> BinderT
      (Lore m)
      (State VNameSource)
      (Body (Lore (BinderT (Lore m) (State VNameSource)))))
-> Result
-> BinderT
     (Lore m)
     (State VNameSource)
     (Body (Lore (BinderT (Lore m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
hist
              oob :: BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
oob = case [VName]
hist of
                [] -> SubExp
-> BinderT
     (Lore m)
     (State VNameSource)
     (Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *). MonadBinder m => SubExp -> m (Exp (Lore m))
eSubExp (SubExp
 -> BinderT
      (Lore m)
      (State VNameSource)
      (Exp (Lore (BinderT (Lore m) (State VNameSource)))))
-> SubExp
-> BinderT
     (Lore m)
     (State VNameSource)
     (Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ Bool -> SubExp
forall v. IsValue v => v -> SubExp
constant Bool
True
                VName
arr : [VName]
_ -> VName
-> [BinderT
      (Lore m)
      (State VNameSource)
      (Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT
     (Lore m)
     (State VNameSource)
     (Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
VName -> [m (Exp (Lore m))] -> m (Exp (Lore m))
eOutOfBounds VName
arr [SubExp
-> BinderT
     (Lore m)
     (State VNameSource)
     (Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *). MonadBinder m => SubExp -> m (Exp (Lore m))
eSubExp SubExp
idx]

          [Char]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) [VName]
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m [VName]
letTupExp [Char]
"new_histo"
            (ExpT (Lore m) -> BinderT (Lore m) (State VNameSource) [VName])
-> (Binder (Lore m) (Body (Lore m))
    -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> BinderT (Lore m) (State VNameSource) [VName]
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< BinderT
  (Lore m)
  (State VNameSource)
  (Exp (Lore (BinderT (Lore m) (State VNameSource))))
-> BinderT
     (Lore m)
     (State VNameSource)
     (Body (Lore (BinderT (Lore m) (State VNameSource))))
-> BinderT
     (Lore m)
     (State VNameSource)
     (Body (Lore (BinderT (Lore m) (State VNameSource))))
-> BinderT
     (Lore m)
     (State VNameSource)
     (Exp (Lore (BinderT (Lore m) (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, BranchType (Lore m) ~ ExtType) =>
m (Exp (Lore m))
-> m (Body (Lore m)) -> m (Body (Lore m)) -> m (Exp (Lore m))
eIf BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
BinderT
  (Lore m)
  (State VNameSource)
  (Exp (Lore (BinderT (Lore m) (State VNameSource))))
oob BinderT
  (Lore m)
  (State VNameSource)
  (Body (Lore (BinderT (Lore m) (State VNameSource))))
outside_bounds_branch
            (Binder (Lore m) (Body (Lore m))
 -> BinderT (Lore m) (State VNameSource) [VName])
-> Binder (Lore m) (Body (Lore m))
-> BinderT (Lore m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ do
              -- Read values from histogram.
              Result
h_val <- [VName]
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
hist ((VName -> BinderT (Lore m) (State VNameSource) SubExp)
 -> BinderT (Lore m) (State VNameSource) Result)
-> (VName -> BinderT (Lore m) (State VNameSource) SubExp)
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ \VName
arr -> do
                Type
arr_t <- VName -> BinderT (Lore m) (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
                [Char]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"read_hist" (Exp (Lore (BinderT (Lore m) (State VNameSource)))
 -> BinderT (Lore m) (State VNameSource) SubExp)
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> Slice SubExp -> Slice SubExp
fullSlice Type
arr_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
idx]

              -- Apply operator.
              Result
h_val' <-
                Lambda (Lore (BinderT (Lore m) (State VNameSource)))
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *).
Transformer m =>
Lambda (Lore m) -> [Exp (Lore m)] -> m Result
bindLambda (HistOp (Lore m) -> Lambda (Lore m)
forall lore. HistOp lore -> Lambda lore
histOp HistOp (Lore m)
op) ([Exp (Lore (BinderT (Lore m) (State VNameSource)))]
 -> BinderT (Lore m) (State VNameSource) Result)
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$
                  (SubExp -> ExpT (Lore m)) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) (Result -> [ExpT (Lore m)]) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> a -> b
$ Result
h_val Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
val

              -- Write values back to histograms.
              [VName]
hist' <- [(VName, SubExp)]
-> ((VName, SubExp) -> BinderT (Lore m) (State VNameSource) VName)
-> BinderT (Lore m) (State VNameSource) [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName] -> Result -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
hist Result
h_val') (((VName, SubExp) -> BinderT (Lore m) (State VNameSource) VName)
 -> BinderT (Lore m) (State VNameSource) [VName])
-> ((VName, SubExp) -> BinderT (Lore m) (State VNameSource) VName)
-> BinderT (Lore m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$ \(VName
arr, SubExp
v) -> do
                Type
arr_t <- VName -> BinderT (Lore m) (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
                [Char]
-> VName
-> Slice SubExp
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
[Char] -> VName -> Slice SubExp -> Exp (Lore m) -> m VName
letInPlace [Char]
"hist_out" VName
arr (Type -> Slice SubExp -> Slice SubExp
fullSlice Type
arr_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
idx]) (Exp (Lore (BinderT (Lore m) (State VNameSource)))
 -> BinderT (Lore m) (State VNameSource) VName)
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
                  BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
v

              Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall (m :: * -> *) a. Monad m => a -> m a
return (Body (Lore m) -> Binder (Lore m) (Body (Lore m)))
-> Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ Result -> Body (Lore m)
forall lore. Bindable lore => Result -> Body lore
resultBody (Result -> Body (Lore m)) -> Result -> Body (Lore m)
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
hist'

        Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall (m :: * -> *) a. Monad m => a -> m a
return (Body (Lore m) -> Binder (Lore m) (Body (Lore m)))
-> Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ Result -> Body (Lore m)
forall lore. Bindable lore => Result -> Body lore
resultBody (Result -> Body (Lore m)) -> Result -> Body (Lore m)
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> Result) -> [VName] -> Result
forall a b. (a -> b) -> a -> b
$ [[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
hists_out''

  -- Wrap up the above into a for-loop.
  Pattern (Lore m) -> ExpT (Lore m) -> m ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore m)
pat (ExpT (Lore m) -> m ()) -> ExpT (Lore m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Lore m), SubExp)]
-> [(FParam (Lore m), SubExp)]
-> LoopForm (Lore m)
-> Body (Lore m)
-> ExpT (Lore m)
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [] [(Param DeclType, SubExp)]
[(FParam (Lore m), SubExp)]
merge (VName
-> IntType
-> SubExp
-> [(LParam (Lore m), VName)]
-> LoopForm (Lore m)
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
iter IntType
Int64 SubExp
len []) Body (Lore m)
loopBody

-- | Recursively first-order-transform a lambda.
--
-- The lambda body is transformed with 'transformBody' inside a fresh
-- body binder, with the lambda parameters brought into scope.  The
-- parameters and return types are carried over unchanged.
transformLambda ::
  ( MonadFreshNames m,
    Bindable lore,
    BinderOps lore,
    LocalScope somelore m,
    SameScope somelore lore,
    LetDec lore ~ LetDec SOACS
  ) =>
  Lambda ->
  m (AST.Lambda lore)
transformLambda (Lambda params body rettype) =
  withNewBody
    <$> runBodyBinder (localScope (scopeOfLParams params) (transformBody body))
  where
    -- Reassemble the lambda around the transformed body.
    withNewBody body' = Lambda params body' rettype

-- | For every given type, bind a new variable (base name @"result"@) to
-- a 'Scratch' array with that type's element type and dimensions.
-- Returns the bound names in the same order as the input types.
resultArray :: Transformer m => [Type] -> m [VName]
resultArray = traverse $ \t ->
  letExp "result" . BasicOp $ Scratch (elemType t) (arrayDims t)

-- | In-place update of several arrays at the same index.
--
-- First binds every expression in the value list to a subexpression
-- (base name @"values"@).  Then it runs the index action and binds its
-- result (base name @"i"@).  Finally it writes the @n@th value into the
-- @n@th array at that index via 'letInPlace'.  Returns the names of the
-- updated arrays.  Surplus arrays or values are ignored, since the two
-- lists are paired with zip semantics.
letwith ::
  Transformer m =>
  [VName] ->
  m (AST.Exp (Lore m)) ->
  [AST.Exp (Lore m)] ->
  m [VName]
letwith ks i vs = do
  val_ses <- letSubExps "values" vs
  i_se <- letSubExp "i" =<< i
  forM (zip ks val_ses) $ \(k, v) -> do
    k_t <- lookupType k
    letInPlace "lw_dest" k (fullSlice k_t [DimFix i_se]) (BasicOp (SubExp v))

-- | Wrap a subexpression as a 'BasicOp' expression and lift it into any
-- 'Applicative'.
pexp :: Applicative f => SubExp -> f (AST.Exp lore)
pexp se = pure (BasicOp (SubExp se))

-- | Inline a lambda by binding its parameters to the given argument
-- expressions, then binding its body statements.
--
-- A parameter of primitive type is bound directly to its argument.  Any
-- other parameter is bound to the argument wrapped by 'eCopy'.
-- Parameters and arguments are paired positionally, and surplus elements
-- of either list are ignored.  The result is the body's result, as
-- returned by 'bodyBind'.
bindLambda ::
  Transformer m =>
  AST.Lambda (Lore m) ->
  [AST.Exp (Lore m)] ->
  m [SubExp]
bindLambda (Lambda params body _) args = do
  zipWithM_ bindParam params args
  bodyBind body
  where
    bindParam param arg
      | primType (paramType param) = letBindNames [paramName param] arg
      | otherwise = letBindNames [paramName param] =<< eCopy (pure arg)

-- | Build loop merge parameters from identifiers and their initial
-- values.  Every parameter is marked 'Unique'.
loopMerge :: [Ident] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge vars = loopMerge' [(var, Unique) | var <- vars]

-- | Build loop merge parameters from identifiers paired with an explicit
-- uniqueness, plus their initial values.
--
-- Each identifier's type is turned into a 'DeclType' with the given
-- uniqueness.  Identifiers and values are paired positionally, and
-- surplus elements of either list are dropped.
loopMerge' :: [(Ident, Uniqueness)] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge' vars vals = zipWith mkMerge vars vals
  where
    mkMerge (Ident pname ptype, u) val = (Param pname (toDecl ptype u), val)