{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | The code generator cannot handle the array combinators (@map@ and
-- friends), so this module transforms them into the equivalent
-- do-loops.  The transformation is currently rather naive, and it is
-- certainly worth investigating when such transformations could
-- instead be expressed in-place.
module Futhark.Transform.FirstOrderTransform
  ( transformFunDef
  , transformStms

  , FirstOrderLore
  , Transformer
  , transformStmRecursively
  , transformLambda
  , transformSOAC
  , transformBody
  )
  where

import Control.Monad.Except
import Control.Monad.State
import qualified Data.Map.Strict as M
import Data.List (zip4)

import qualified Futhark.Representation.AST as AST
import Futhark.Representation.SOACS
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Util (chunks, splitAt3)

-- | Constraint synonym listing what a lore must satisfy to be the
-- output lore of the first-order transformation: it must be bindable
-- (with 'BinderOps'), its 'Op' must be an instance of 'CanBeAliased',
-- and its let-bound and lambda-parameter attributes must coincide with
-- those of 'SOACS'.
type FirstOrderLore lore =
  ( LetAttr SOACS ~ LetAttr lore
  , LParamAttr SOACS ~ LParamAttr lore
  , Bindable lore
  , BinderOps lore
  , CanBeAliased (Op lore)
  )

-- | Transform a function definition from 'SOACS' into the target lore.
--
-- The body is transformed inside a binder whose initial scope is
-- @consts_scope@, extended with the function's parameters.  The
-- transformation is wrapped in 'insertStmsM', so the statements it
-- produces are placed in the resulting body; the separate statement
-- list returned by 'runBinderT' is therefore ignored.  The entry
-- point, name, return type and parameters are carried over unchanged.
transformFunDef :: (MonadFreshNames m, FirstOrderLore tolore) =>
                   Scope tolore -> FunDef SOACS -> m (AST.FunDef tolore)
transformFunDef consts_scope (FunDef entry fname rettype params body) = do
  (body', _) <- modifyNameSource $ runState $ runBinderT m consts_scope
  return $ FunDef entry fname rettype params body'
  where m = localScope (scopeOfFParams params) $
            insertStmsM $ transformBody body

-- | Transform a sequence of 'SOACS' statements into statements of the
-- target lore.  Each statement is handled by 'transformStmRecursively'
-- in a binder that starts from an empty scope ('mempty'), and the
-- statements collected by that binder are returned.
transformStms :: (MonadFreshNames m, FirstOrderLore tolore) =>
                 Stms SOACS -> m (AST.Stms tolore)
transformStms stms =
  fmap snd $ modifyNameSource $ runState $ runBinderT m mempty
  where m = mapM_ transformStmRecursively stms

-- | Constraint synonym for a monad that can drive the first-order
-- transformation: a 'MonadBinder' with a local scope over its own
-- lore, where that lore is bindable (with 'BinderOps'), shares its
-- lambda-parameter attribute with 'SOACS', and has an 'Op' that is an
-- instance of 'CanBeAliased'.
type Transformer m =
  ( LocalScope (Lore m) m
  , MonadBinder m
  , BinderOps (Lore m)
  , Bindable (Lore m)
  , CanBeAliased (Op (Lore m))
  , LParamAttr SOACS ~ LParamAttr (Lore m)
  )

-- | Transform every statement of a 'SOACS' body with
-- 'transformStmRecursively', producing a body in the lore of the
-- current binder monad.  The body's result is kept unchanged, and
-- 'insertStmsM' places the generated statements in the returned body.
transformBody :: (Transformer m, LetAttr (Lore m) ~ LetAttr SOACS) =>
                 Body -> m (AST.Body (Lore m))
transformBody (Body () bnds res) = insertStmsM $ do
  mapM_ transformStmRecursively bnds
  return $ resultBody res

-- | First transform any nested 'Body' or 'Lambda' elements, then
-- apply 'transformSOAC' if the expression is a SOAC.  In both cases
-- the statement's certificates are applied via 'certifying'.
transformStmRecursively :: (Transformer m, LetAttr (Lore m) ~ LetAttr SOACS) =>
                           Stm -> m ()

-- A SOAC: transform the lambdas it contains with 'transformLambda',
-- then hand the SOAC itself to 'transformSOAC'.
transformStmRecursively (Let pat aux (Op soac)) =
  certifying (stmAuxCerts aux) $
  transformSOAC pat =<< mapSOACM soacTransform soac
  where soacTransform = identitySOACMapper { mapOnSOACLambda = transformLambda }

-- Any other expression: transform its nested bodies (in the scope they
-- are given by 'mapExpM') and rebind the result under the original
-- pattern.  Types and parameters are passed through unchanged.
transformStmRecursively (Let pat aux e) =
  certifying (stmAuxCerts aux) $
  letBind_ pat =<< mapExpM transform e
  where transform = identityMapper { mapOnBody = \scope -> localScope scope . transformBody
                                   , mapOnRetType = return
                                   , mapOnBranchType = return
                                   , mapOnFParam = return
                                   , mapOnLParam = return
                                     -- Expected to be unreachable: every 'Op'
                                     -- expression is matched by the equation above.
                                   , mapOnOp = error "Unhandled Op in first order transform"
                                   }

-- | Transform a single 'SOAC' into a do-loop.  The body of the lambda
-- is untouched, and may or may not contain further 'SOAC's depending
-- on the given lore.
transformSOAC :: Transformer m =>
                 AST.Pattern (Lore m)
              -> SOAC (Lore m)
              -> m ()

transformSOAC :: Pattern (Lore m) -> SOAC (Lore m) -> m ()
transformSOAC Pattern (Lore m)
pat (Screma SubExp
w form :: ScremaForm (Lore m)
form@(ScremaForm [Scan (Lore m)]
scans [Reduce (Lore m)]
reds Lambda (Lore m)
map_lam) [VName]
arrs) = do
  -- Start by combining all the reduction parts into a single operator
  let Reduce Commutativity
_ Lambda (Lore m)
red_lam Result
red_nes = [Reduce (Lore m)] -> Reduce (Lore m)
forall lore. Bindable lore => [Reduce lore] -> Reduce lore
singleReduce [Reduce (Lore m)]
reds
      Scan Lambda (Lore m)
scan_lam Result
scan_nes = [Scan (Lore m)] -> Scan (Lore m)
forall lore. Bindable lore => [Scan lore] -> Scan lore
singleScan [Scan (Lore m)]
scans
      ([Type]
scan_arr_ts, [Type]
_red_ts, [Type]
map_arr_ts) =
        Int -> Int -> [Type] -> ([Type], [Type], [Type])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
scan_nes) (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
red_nes) ([Type] -> ([Type], [Type], [Type]))
-> [Type] -> ([Type], [Type], [Type])
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm (Lore m) -> [Type]
forall lore. SubExp -> ScremaForm lore -> [Type]
scremaType SubExp
w ScremaForm (Lore m)
form
  [VName]
scan_arrs <- [Type] -> m [VName]
forall (m :: * -> *). Transformer m => [Type] -> m [VName]
resultArray [Type]
scan_arr_ts
  [VName]
map_arrs <- [Type] -> m [VName]
forall (m :: * -> *). Transformer m => [Type] -> m [VName]
resultArray [Type]
map_arr_ts

  -- We construct a loop that contains several groups of merge
  -- parameters:
  --
  -- (0) scan accumulator.
  -- (1) scan results.
  -- (2) reduce results (and accumulator).
  -- (3) map results.
  --
  -- Inside the loop, the parameters to map_lam become for-in
  -- parameters.

  [Param DeclType]
scanacc_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) attr.
MonadFreshNames m =>
[Char] -> attr -> m (Param attr)
newParam [Char]
"scanacc" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Nonunique) ([Type] -> m [Param DeclType]) -> [Type] -> m [Param DeclType]
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Lore m)
scan_lam
  [Param DeclType]
scanout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) attr.
MonadFreshNames m =>
[Char] -> attr -> m (Param attr)
newParam [Char]
"scanout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Unique) [Type]
scan_arr_ts
  [Param DeclType]
redout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) attr.
MonadFreshNames m =>
[Char] -> attr -> m (Param attr)
newParam [Char]
"redout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Nonunique) ([Type] -> m [Param DeclType]) -> [Type] -> m [Param DeclType]
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda (Lore m)
red_lam
  [Param DeclType]
mapout_params <- (Type -> m (Param DeclType)) -> [Type] -> m [Param DeclType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> DeclType -> m (Param DeclType)
forall (m :: * -> *) attr.
MonadFreshNames m =>
[Char] -> attr -> m (Param attr)
newParam [Char]
"mapout" (DeclType -> m (Param DeclType))
-> (Type -> DeclType) -> Type -> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> DeclType) -> Uniqueness -> Type -> DeclType
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> DeclType
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Unique) [Type]
map_arr_ts

  let merge :: [(Param DeclType, SubExp)]
merge = [[(Param DeclType, SubExp)]] -> [(Param DeclType, SubExp)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
scanacc_params Result
scan_nes,
                      [Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
scanout_params (Result -> [(Param DeclType, SubExp)])
-> Result -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
scan_arrs,
                      [Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
redout_params Result
red_nes,
                      [Param DeclType] -> Result -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
mapout_params (Result -> [(Param DeclType, SubExp)])
-> Result -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
map_arrs]
  VName
i <- [Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"i"
  let loopform :: LoopForm (Lore m)
loopform = VName
-> IntType
-> SubExp
-> [(LParam (Lore m), VName)]
-> LoopForm (Lore m)
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
i IntType
Int32 SubExp
w []

  Body (Lore m)
loop_body <- Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m)) -> m (Body (Lore m))
forall a b. (a -> b) -> a -> b
$
               Scope (Lore m)
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param DeclType] -> Scope (Lore m)
forall lore attr.
(FParamAttr lore ~ attr) =>
[Param attr] -> Scope lore
scopeOfFParams ([Param DeclType] -> Scope (Lore m))
-> [Param DeclType] -> Scope (Lore m)
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
merge) (Binder (Lore m) (Body (Lore m))
 -> Binder (Lore m) (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$
               LoopForm (Lore m)
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall lore a (m :: * -> *) b.
(Scoped lore a, LocalScope lore m) =>
a -> m b -> m b
inScopeOf LoopForm (Lore m)
loopform (Binder (Lore m) (Body (Lore m))
 -> Binder (Lore m) (Body (Lore m)))
-> Binder (Lore m) (Body (Lore m))
-> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ do

    [(Param Type, VName)]
-> ((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
-> BinderT (Lore m) (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda (Lore m) -> [LParam (Lore m)]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda (Lore m)
map_lam) [VName]
arrs) (((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
 -> BinderT (Lore m) (State VNameSource) ())
-> ((Param Type, VName) -> BinderT (Lore m) (State VNameSource) ())
-> BinderT (Lore m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, VName
arr) -> do
      Type
arr_t <- VName -> BinderT (Lore m) (State VNameSource) Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
      [VName]
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames_ [Param Type -> VName
forall attr. Param attr -> VName
paramName Param Type
p] (Exp (Lore (BinderT (Lore m) (State VNameSource)))
 -> BinderT (Lore m) (State VNameSource) ())
-> Exp (Lore (BinderT (Lore m) (State VNameSource)))
-> BinderT (Lore m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m)) -> BasicOp -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
        Type -> Slice SubExp -> Slice SubExp
fullSlice Type
arr_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i]

    -- Insert the statements of the lambda.  We have taken care to
    -- ensure that the parameters are bound at this point.
    (Stm (Lore m) -> BinderT (Lore m) (State VNameSource) ())
-> Seq (Stm (Lore m)) -> BinderT (Lore m) (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm (Lore m) -> BinderT (Lore m) (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm (Seq (Stm (Lore m)) -> BinderT (Lore m) (State VNameSource) ())
-> Seq (Stm (Lore m)) -> BinderT (Lore m) (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Body (Lore m) -> Seq (Stm (Lore m))
forall lore. BodyT lore -> Stms lore
bodyStms (Body (Lore m) -> Seq (Stm (Lore m)))
-> Body (Lore m) -> Seq (Stm (Lore m))
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> Body (Lore m)
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda (Lore m)
map_lam
    -- Split into scan results, reduce results, and map results.
    let (Result
scan_res, Result
red_res, Result
map_res) =
          Int -> Int -> Result -> (Result, Result, Result)
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
scan_nes) (Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
red_nes) (Result -> (Result, Result, Result))
-> Result -> (Result, Result, Result)
forall a b. (a -> b) -> a -> b
$
          Body (Lore m) -> Result
forall lore. BodyT lore -> Result
bodyResult (Body (Lore m) -> Result) -> Body (Lore m) -> Result
forall a b. (a -> b) -> a -> b
$ Lambda (Lore m) -> Body (Lore m)
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda (Lore m)
map_lam

    Result
scan_res' <- Lambda (Lore (BinderT (Lore m) (State VNameSource)))
-> [BinderT
      (Lore m)
      (State VNameSource)
      (Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *).
MonadBinder m =>
Lambda (Lore m) -> [m (Exp (Lore m))] -> m Result
eLambda Lambda (Lore m)
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
scan_lam ([BinderT
    (Lore m)
    (State VNameSource)
    (Exp (Lore (BinderT (Lore m) (State VNameSource))))]
 -> BinderT (Lore m) (State VNameSource) Result)
-> [BinderT
      (Lore m)
      (State VNameSource)
      (Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ (SubExp -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))]
forall a b. (a -> b) -> [a] -> [b]
map (ExpT (Lore m)
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT (Lore m)
 -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> (SubExp -> ExpT (Lore m))
-> SubExp
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) (Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))])
-> Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))]
forall a b. (a -> b) -> a -> b
$
                 (Param DeclType -> SubExp) -> [Param DeclType] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (Param DeclType -> VName) -> Param DeclType -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> VName
forall attr. Param attr -> VName
paramName) [Param DeclType]
scanacc_params Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
scan_res
    Result
red_res' <- Lambda (Lore (BinderT (Lore m) (State VNameSource)))
-> [BinderT
      (Lore m)
      (State VNameSource)
      (Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT (Lore m) (State VNameSource) Result
forall (m :: * -> *).
MonadBinder m =>
Lambda (Lore m) -> [m (Exp (Lore m))] -> m Result
eLambda Lambda (Lore m)
Lambda (Lore (BinderT (Lore m) (State VNameSource)))
red_lam ([BinderT
    (Lore m)
    (State VNameSource)
    (Exp (Lore (BinderT (Lore m) (State VNameSource))))]
 -> BinderT (Lore m) (State VNameSource) Result)
-> [BinderT
      (Lore m)
      (State VNameSource)
      (Exp (Lore (BinderT (Lore m) (State VNameSource))))]
-> BinderT (Lore m) (State VNameSource) Result
forall a b. (a -> b) -> a -> b
$ (SubExp -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))]
forall a b. (a -> b) -> [a] -> [b]
map (ExpT (Lore m)
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpT (Lore m)
 -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m)))
-> (SubExp -> ExpT (Lore m))
-> SubExp
-> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) (Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))])
-> Result -> [BinderT (Lore m) (State VNameSource) (ExpT (Lore m))]
forall a b. (a -> b) -> a -> b
$
                (Param DeclType -> SubExp) -> [Param DeclType] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (Param DeclType -> VName) -> Param DeclType -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> VName
forall attr. Param attr -> VName
paramName) [Param DeclType]
redout_params Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
red_res

    -- Write the scan accumulator to the scan result arrays.
    [VName]
scan_outarrs <- [VName]
-> BinderT
     (Lore m)
     (State VNameSource)
     (Exp (Lore (BinderT (Lore m) (State VNameSource))))
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) [VName]
forall (m :: * -> *).
Transformer m =>
[VName] -> m (Exp (Lore m)) -> [Exp (Lore m)] -> m [VName]
letwith ((Param DeclType -> VName) -> [Param DeclType] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param DeclType -> VName
forall attr. Param attr -> VName
paramName [Param DeclType]
scanout_params) (SubExp -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall (f :: * -> *) lore. Applicative f => SubExp -> f (Exp lore)
pexp (VName -> SubExp
Var VName
i)) ([Exp (Lore (BinderT (Lore m) (State VNameSource)))]
 -> BinderT (Lore m) (State VNameSource) [VName])
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
                    (SubExp -> ExpT (Lore m)) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) Result
scan_res'

    -- Write the map results to the map result arrays.
    [VName]
map_outarrs <- [VName]
-> BinderT
     (Lore m)
     (State VNameSource)
     (Exp (Lore (BinderT (Lore m) (State VNameSource))))
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) [VName]
forall (m :: * -> *).
Transformer m =>
[VName] -> m (Exp (Lore m)) -> [Exp (Lore m)] -> m [VName]
letwith ((Param DeclType -> VName) -> [Param DeclType] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param DeclType -> VName
forall attr. Param attr -> VName
paramName [Param DeclType]
mapout_params) (SubExp -> BinderT (Lore m) (State VNameSource) (ExpT (Lore m))
forall (f :: * -> *) lore. Applicative f => SubExp -> f (Exp lore)
pexp (VName -> SubExp
Var VName
i)) ([Exp (Lore (BinderT (Lore m) (State VNameSource)))]
 -> BinderT (Lore m) (State VNameSource) [VName])
-> [Exp (Lore (BinderT (Lore m) (State VNameSource)))]
-> BinderT (Lore m) (State VNameSource) [VName]
forall a b. (a -> b) -> a -> b
$
                   (SubExp -> ExpT (Lore m)) -> Result -> [ExpT (Lore m)]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> ExpT (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT (Lore m))
-> (SubExp -> BasicOp) -> SubExp -> ExpT (Lore m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp) Result
map_res

    Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall (m :: * -> *) a. Monad m => a -> m a
return (Body (Lore m) -> Binder (Lore m) (Body (Lore m)))
-> Body (Lore m) -> Binder (Lore m) (Body (Lore m))
forall a b. (a -> b) -> a -> b
$ Result -> Body (Lore m)
forall lore. Bindable lore => Result -> Body lore
resultBody (Result -> Body (Lore m)) -> Result -> Body (Lore m)
forall a b. (a -> b) -> a -> b
$ [Result] -> Result
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [Result
scan_res',
                                  (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
scan_outarrs,
                                  Result
red_res',
                                  (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
map_outarrs]

  -- We need to discard the final scan accumulators, as they are not
  -- bound in the original pattern.
  [VName]
names <- ([VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++Pattern (Lore m) -> [VName]
forall attr. PatternT attr -> [VName]
patternNames Pattern (Lore m)
pat)
           ([VName] -> [VName]) -> m [VName] -> m [VName]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int -> m VName -> m [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM ([Param DeclType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Param DeclType]
scanacc_params) ([Char] -> m VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"discard")
  [VName] -> ExpT (Lore m) -> m ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames_ [VName]
names (ExpT (Lore m) -> m ()) -> ExpT (Lore m) -> m ()
forall a b. (a -> b) -> a -> b
$ [(FParam (Lore m), SubExp)]
-> [(FParam (Lore m), SubExp)]
-> LoopForm (Lore m)
-> Body (Lore m)
-> ExpT (Lore m)
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [] [(Param DeclType, SubExp)]
[(FParam (Lore m), SubExp)]
merge LoopForm (Lore m)
loopform Body (Lore m)
loop_body

-- Streams are handed off to 'sequentialStreamWholeArray', seeded with
-- the accumulator values carried by the stream form (as extracted by
-- 'getStreamAccums').
transformSOAC pat (Stream w form lam arrs) =
  sequentialStreamWholeArray pat w (getStreamAccums form) lam arrs

-- Scatter is turned into a sequential for-loop of @len@ iterations.
-- Iteration @i@ reads element @i@ of every array in @ivs@, applies the
-- scatter lambda to those elements, and writes each resulting
-- (index, value) pair into its destination array.  The destination
-- arrays are threaded through the loop as merge parameters.
transformSOAC pat (Scatter len lam ivs as) = do
  i <- newVName "write_iter"

  let (_dest_ws, dest_ns, dest_arrs) = unzip3 as
  dest_ts <- mapM lookupType dest_arrs
  dest_outs <- mapM (newIdent "write_out") dest_ts

  -- Scatter is in-place: the input arrays initialise the merge
  -- parameters that act as the output arrays.
  let merge = loopMerge dest_outs (map Var dest_arrs)
      loop_scope =
        M.insert i (IndexInfo Int32) $ scopeOfFParams (map fst merge)
      -- The lambda returns all of its indices first, then all values.
      num_indices = length (lambdaReturnType lam) `div` 2

  body <- runBodyBinder $ localScope loop_scope $ do
    elems <- forM ivs $ \iv -> do
      iv_t <- lookupType iv
      letSubExp "write_iv" $ BasicOp $
        Index iv $ fullSlice iv_t [DimFix (Var i)]
    lam_res <- bindLambda lam $ map (BasicOp . SubExp) elems

    let (idx_res, val_res) = splitAt num_indices lam_res
        per_dest = zip3 (chunks dest_ns idx_res)
                        (chunks dest_ns val_res)
                        (map identName dest_outs)
        writeOne arr (idx, v) =
          letExp "write_out" =<< eWriteArray arr [eSubExp idx] (eSubExp v)

    -- Apply the writes for each destination one after another,
    -- threading the updated array through the fold.
    written <- forM per_dest $ \(dest_idxs, dest_vals, arr) ->
      foldM writeOne arr (zip dest_idxs dest_vals)
    return $ resultBody $ map Var written

  letBind_ pat $ DoLoop [] merge (ForLoop i Int32 len []) body

-- A histogram is turned into a sequential for-loop over the @len@
-- input elements.  In each iteration the bucket function maps the
-- current image elements to one index per operator plus that
-- operator's values.  For every operator whose index is in bounds, the
-- current bucket contents are combined with the new values by the
-- operator and written back in place.
transformSOAC pat (Hist len ops bucket_fun imgs) = do
  i <- newVName "iter"

  -- The destination arrays of all operators become merge parameters.
  let dests = concatMap histDest ops
  dest_ts <- mapM lookupType dests
  hists_out <- mapM (newIdent "dests") dest_ts
  let merge = loopMerge hists_out (map Var dests)
      loop_scope =
        M.insert i (IndexInfo Int32) $ scopeOfFParams (map fst merge)
      -- Number of values (and of destination arrays) per operator.
      op_widths = map (length . lambdaReturnType . histOp) ops
      num_ops = length ops

  body <- runBodyBinder $ localScope loop_scope $ do
    -- Read element i of every image and apply the bucket function.
    pixels <- forM imgs $ \img -> do
      img_t <- lookupType img
      letSubExp "pixel" $ BasicOp $
        Index img $ fullSlice img_t [DimFix (Var i)]
    bucket_res <- bindLambda bucket_fun $ map (BasicOp . SubExp) pixels

    -- The bucket function yields one index per operator, followed by
    -- the values for every operator.
    let (inds, rest) = splitAt num_ops bucket_res
        vals = chunks op_widths rest
        hists = chunks op_widths (map identName hists_out)

    updated <- forM (zip4 hists ops inds vals) $ \(hist, op, idx, val) -> do
      -- An out-of-bounds index leaves the histograms unchanged.
      let unchanged = insertStmsM $ resultBodyM $ map Var hist
          is_oob = case hist of
                     [] -> eSubExp $ constant True
                     arr : _ -> eOutOfBounds arr [eSubExp idx]
          update = do
            -- Read the current bucket contents.
            cur <- forM hist $ \arr -> do
              arr_t <- lookupType arr
              letSubExp "read_hist" $ BasicOp $
                Index arr $ fullSlice arr_t [DimFix idx]
            -- Combine them with the new values using the operator.
            new <- bindLambda (histOp op) $
                   map (BasicOp . SubExp) $ cur ++ val
            -- Write the combined values back in place.
            hist' <- forM (zip hist new) $ \(arr, v) -> do
              arr_t <- lookupType arr
              letInPlace "hist_out" arr (fullSlice arr_t [DimFix idx]) $
                BasicOp $ SubExp v
            return $ resultBody $ map Var hist'
      letTupExp "new_histo" =<< eIf is_oob unchanged update

    return $ resultBody $ map Var $ concat updated

  -- Wrap everything up in a sequential for-loop.
  letBind_ pat $ DoLoop [] merge (ForLoop i Int32 len []) body

-- | Recursively first-order-transform a lambda.
--
-- The parameters and return types are reused unchanged.  Only the body
-- is transformed (via 'transformBody'), which runs inside a fresh body
-- binder ('runBodyBinder') with the lambda's parameters brought into
-- scope.
transformLambda :: (MonadFreshNames m,
                    Bindable lore, BinderOps lore,
                    LocalScope somelore m,
                    SameScope somelore lore,
                    LetAttr lore ~ LetAttr SOACS,
                    CanBeAliased (Op lore)) =>
                   Lambda -> m (AST.Lambda lore)
transformLambda (Lambda params body rettype) = do
  body' <- runBodyBinder $
           localScope (scopeOfLParams params) $
           transformBody body
  return $ Lambda params body' rettype

-- | Bind one fresh array per given type and return their names.  Each
-- array is named @result@ and is created with 'Scratch' from the
-- type's element type and dimensions.
--
-- NOTE(review): 'Scratch' arrays are presumably uninitialised, so
-- callers are expected to fill them before reading.  Confirm.
resultArray :: Transformer m => [Type] -> m [VName]
resultArray = mapM oneArray
  -- No local signature here: one mentioning the outer @m@ would need
  -- ScopedTypeVariables, which this module does not enable.
  where oneArray t =
          letExp "result" $ BasicOp $ Scratch (elemType t) (arrayDims t)

-- | @letwith ks i vs@ writes each value in @vs@ into position @i@ of the
-- corresponding array in @ks@, using one in-place update ('letInPlace')
-- per array.  It returns the names of the updated arrays.
--
-- The value expressions are bound (as @values@) before the index
-- action @i@ is run and bound (as @i@).  Arrays and values are paired
-- with 'zipWithM', so surplus elements of the longer list are ignored.
letwith :: Transformer m =>
           [VName] -> m (AST.Exp (Lore m)) -> [AST.Exp (Lore m)]
        -> m [VName]
letwith ks i vs = do
  vs' <- letSubExps "values" vs
  i' <- letSubExp "i" =<< i
  -- No local type signature: mentioning the outer @m@ would require
  -- ScopedTypeVariables, which this module does not enable.
  let update k v = do
        k_t <- lookupType k
        -- Only the outermost index is fixed; 'fullSlice' presumably
        -- extends it over the remaining dimensions of @k_t@.
        letInPlace "lw_dest" k (fullSlice k_t [DimFix i']) $
          BasicOp $ SubExp v
  zipWithM update ks vs'

-- | Lift a 'SubExp' into a trivial expression (a 'BasicOp' wrapping the
-- 'SubExp') and return it in any 'Applicative'.
pexp :: Applicative f => SubExp -> f (AST.Exp lore)
pexp = pure . BasicOp . SubExp

-- | Inline a lambda applied to the given argument expressions.  Each
-- parameter is bound to its argument, then the lambda body's statements
-- are bound and its result is returned.
--
-- A parameter of primitive type is bound directly to its argument.
-- Any other parameter is bound to a copy of the argument ('eCopy'),
-- presumably so that the body may consume the parameter without
-- aliasing the argument.  Parameters and arguments are paired with
-- 'zip', so surplus entries on either side are ignored.  The lambda's
-- declared return types are not consulted.
bindLambda :: Transformer m =>
              AST.Lambda (Lore m) -> [AST.Exp (Lore m)]
           -> m [SubExp]
bindLambda (Lambda params body _) args = do
  forM_ (zip params args) $ \(param, arg) ->
    if primType $ paramType param
    then letBindNames [paramName param] arg
    else letBindNames [paramName param] =<< eCopy (pure arg)
  bodyBind body

-- | Build loop merge parameters from identifiers and their initial
-- values, marking every parameter 'Unique'.  This is 'loopMerge'' with
-- a uniqueness of 'Unique' for each identifier.
loopMerge :: [Ident] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge vars = loopMerge' $ zip vars $ repeat Unique

-- | Build loop merge parameters.  Each identifier becomes a parameter
-- with the same name, whose declared type is the identifier's type
-- with the given uniqueness.  Each parameter is paired with its
-- initial value.  The two lists are combined with 'zip', so surplus
-- entries in either list are dropped.
loopMerge' :: [(Ident,Uniqueness)] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge' vars vals = [ (Param pname $ toDecl ptype u, val)
                       | ((Ident pname ptype, u), val) <- zip vars vals ]