{-# LANGUAGE TypeFamilies #-}
module Futhark.Transform.FirstOrderTransform
( transformFunDef,
transformConsts,
FirstOrderRep,
Transformer,
transformStmRecursively,
transformLambda,
transformSOAC,
)
where
import Control.Monad.Except
import Control.Monad.State
import Data.List (find, zip4)
import Data.Map.Strict qualified as M
import Futhark.Analysis.Alias qualified as Alias
import Futhark.IR qualified as AST
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Util (chunks, splitAt3)
-- | Constraints a target representation must satisfy before SOACS code
-- can be first-order-transformed into it.  The target must share the
-- SOACS let-bound and lambda-parameter decorations (so patterns and
-- lambda parameters can be reused unchanged), its operations must be
-- alias-analysable, and it must support statement building.
type FirstOrderRep rep =
  ( LetDec SOACS ~ LetDec rep,
    LParamInfo SOACS ~ LParamInfo rep,
    CanBeAliased (Op rep),
    BuilderOps rep,
    Buildable rep
  )
-- | First-order-transform a single function.
--
-- The body is transformed under the given scope (presumably the scope
-- of top-level constants; confirm at call sites), extended with the
-- function's own parameters.  The entry point, attributes, name, return
-- type and parameters are carried over unchanged; only the body is
-- rewritten.
transformFunDef ::
  (MonadFreshNames m, FirstOrderRep torep) =>
  Scope torep ->
  FunDef SOACS ->
  m (AST.FunDef torep)
transformFunDef consts_scope (FunDef entry attrs fname rettype params body) = do
  -- Any statements left at the builder's top level are dropped; the
  -- transformed body is assembled by 'transformBody' itself.
  (body', _) <- modifyNameSource $ runState $ runBuilderT build_body consts_scope
  pure $ FunDef entry attrs fname rettype params body'
  where
    build_body = localScope (scopeOfFParams params) $ transformBody body
-- | First-order-transform a sequence of top-level constant statements.
--
-- The transformation starts from an empty scope.  The statements
-- emitted by the builder are returned, and its unit result is dropped.
transformConsts ::
  (MonadFreshNames m, FirstOrderRep torep) =>
  Stms SOACS ->
  m (AST.Stms torep)
transformConsts stms =
  fmap snd $ modifyNameSource $ runState $ runBuilderT build_stms mempty
  where
    build_stms = mapM_ transformStmRecursively stms
-- | Constraints a monad must satisfy to drive first-order
-- transformation.  It must be a statement builder with a local scope
-- in its own representation, that representation must be buildable and
-- alias-analysable, and it must share the SOACS lambda-parameter
-- decoration.
type Transformer m =
  ( LParamInfo SOACS ~ LParamInfo (Rep m),
    CanBeAliased (Op (Rep m)),
    BuilderOps (Rep m),
    Buildable (Rep m),
    LocalScope (Rep m) m,
    MonadBuilder m
  )
-- | Transform every statement of a body with 'transformStmRecursively'.
--
-- The body result is kept unchanged, and the new body is assembled with
-- 'buildBody_'.  SOACS bodies carry a unit decoration, which is matched
-- directly.
transformBody ::
  (Transformer m, LetDec (Rep m) ~ LetDec SOACS) =>
  Body SOACS ->
  m (AST.Body (Rep m))
transformBody (Body () stms res) = buildBody_ $ do
  mapM_ transformStmRecursively stms
  pure res
-- | Transform a statement, including everything nested inside it.
--
-- * For a SOAC ('Op'), the lambdas inside it are first transformed with
--   'transformLambda', and the SOAC is then lowered by 'transformSOAC'.
-- * For any other expression, nested bodies are transformed recursively
--   with 'transformBody' (under the scope supplied by 'mapExpM'), and
--   the result is rebound to the original pattern.
--
-- In both cases the statement's 'StmAux' is re-applied with 'auxing'.
transformStmRecursively ::
  (Transformer m, LetDec (Rep m) ~ LetDec SOACS) => Stm SOACS -> m ()
transformStmRecursively (Let pat aux (Op soac)) =
  auxing aux $ transformSOAC pat =<< mapSOACM soacTransform soac
  where
    soacTransform = identitySOACMapper {mapOnSOACLambda = transformLambda}
transformStmRecursively (Let pat aux e) =
  auxing aux $ letBind pat =<< mapExpM transform e
  where
    -- Every representation-indexed field must be set, because this
    -- changes the mapper's representation from SOACS to 'Rep m'.
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody,
          mapOnRetType = pure,
          mapOnBranchType = pure,
          mapOnFParam = pure,
          mapOnLParam = pure,
          -- Should be unreachable: the equation above handles every
          -- statement whose expression is an 'Op'.
          mapOnOp = error "Unhandled Op in first order transform"
        }
-- | Produce initial "arrays" for the map and scan outputs of a Screma,
-- one per requested type in @ts@.
--
-- "Arrays" is in quotes because some outputs may be accumulators.  If a
-- requested type is an accumulator ('Acc') and equals the type of one
-- of the input arrays @arrs@, that input is returned as-is.  Every
-- other type, including accumulators without a matching input, gets a
-- fresh variable named @result@ bound to 'eBlank' of that type.
resultArray :: Transformer m => [VName] -> [Type] -> m [VName]
resultArray arrs ts = do
  arrs_ts <- mapM lookupType arrs
  let oneArray t@Acc {}
        | Just (v, _) <- find ((== t) . snd) (zip arrs arrs_ts) =
            pure v
      oneArray t =
        letExp "result" =<< eBlank t
  mapM oneArray ts
-- | Transform a single 'SOAC' into sequential code; Screma and Stream
-- are lowered to a 'DoLoop' bound to the given pattern.
--
-- Lambdas nested in the SOAC are used as given.  When this is reached
-- through 'transformStmRecursively', they have already been processed
-- by 'transformLambda'.  The automatic-differentiation SOACs ('JVP' and
-- 'VJP') are not supported and cause an 'error'; presumably they are
-- expected to be eliminated before this pass (NOTE(review): confirm).
transformSOAC ::
  Transformer m =>
  Pat (LetDec (Rep m)) ->
  SOAC (Rep m) ->
  m ()
transformSOAC _ JVP {} =
  error "transformSOAC: unhandled JVP"
transformSOAC _ VJP {} =
  error "transformSOAC: unhandled VJP"
transformSOAC pat (Screma w arrs form@(ScremaForm scans reds map_lam)) = do
  -- Fuse all scan operators into one and all reduction operators into
  -- one, so that each can be applied with a single lambda per iteration.
  let Reduce _ red_lam red_nes = singleReduce reds
      Scan scan_lam scan_nes = singleScan scans
      -- The Screma result is laid out as scan results, then reduction
      -- results, then map results.
      (scan_arr_ts, _red_ts, map_arr_ts) =
        splitAt3 (length scan_nes) (length red_nes) $ scremaType w form

  -- Initial values for the scan and map output arrays.  Map outputs
  -- may reuse an accumulator input directly (see 'resultArray').
  scan_arrs <- resultArray [] scan_arr_ts
  map_arrs <- resultArray arrs map_arr_ts

  -- Loop parameters for each group of loop-carried values: the running
  -- scan accumulator, the scan output arrays, the running reduction
  -- value and the map output arrays.
  scanacc_params <-
    mapM (newParam "scanacc" . flip toDecl Nonunique) $ lambdaReturnType scan_lam
  scanout_params <- mapM (newParam "scanout" . flip toDecl Unique) scan_arr_ts
  redout_params <-
    mapM (newParam "redout" . flip toDecl Nonunique) $ lambdaReturnType red_lam
  mapout_params <- mapM (newParam "mapout" . flip toDecl Unique) map_arr_ts

  arr_ts <- mapM lookupType arrs

  -- For an accumulator-typed input, find the map output loop parameter
  -- whose accumulator type has the same name @c@.
  let paramForAcc (Acc c _ _ _) = find (f . paramType) mapout_params
        where
          f (Acc c2 _ _ _) = c == c2
          f _ = False
      paramForAcc _ = Nothing

  -- Loop-carried values, in the same order as the loop body's result.
  let merge =
        concat
          [ zip scanacc_params scan_nes,
            zip scanout_params $ map Var scan_arrs,
            zip redout_params red_nes,
            zip mapout_params $ map Var map_arrs
          ]
  i <- newVName "i"
  let loopform = ForLoop i Int64 w []
      -- Names consumed by the map lambda.  For these, the loop body
      -- binds a copy of the indexed element instead of the element.
      lam_cons = consumedByLambda $ Alias.analyseLambda mempty map_lam

  loop_body <-
    runBodyBuilder
      . localScope (scopeOfFParams (map fst merge) <> scopeOf loopform)
      $ do
        -- Bind each map lambda parameter for iteration i.
        forM_ (zip3 (lambdaParams map_lam) arrs arr_ts) $ \(p, arr, arr_t) ->
          case paramForAcc arr_t of
            Just acc_out_p ->
              -- Accumulator inputs are threaded through the loop.
              letBindNames [paramName p] $
                BasicOp $
                  SubExp $
                    Var $
                      paramName acc_out_p
            Nothing
              | paramName p `nameIn` lam_cons -> do
                  p' <-
                    letExp (baseString (paramName p)) $
                      BasicOp $
                        Index arr $
                          fullSlice arr_t [DimFix $ Var i]
                  letBindNames [paramName p] $ BasicOp $ Copy p'
              | otherwise ->
                  letBindNames [paramName p] $
                    BasicOp $
                      Index arr $
                        fullSlice arr_t [DimFix $ Var i]

        -- Inline the statements of the map lambda.  Its parameters have
        -- been bound above.
        mapM_ addStm $ bodyStms $ lambdaBody map_lam

        -- Split the lambda result into scan, reduce and map parts.
        let (scan_res, red_res, map_res) =
              splitAt3 (length scan_nes) (length red_nes) $
                bodyResult $
                  lambdaBody map_lam

        -- Combine the running accumulators with this iteration's values.
        scan_res' <-
          eLambda scan_lam $
            map (pure . BasicOp . SubExp) $
              map (Var . paramName) scanacc_params ++ map resSubExp scan_res
        red_res' <-
          eLambda red_lam $
            map (pure . BasicOp . SubExp) $
              map (Var . paramName) redout_params ++ map resSubExp red_res

        -- Store the new scan values into the scan output arrays at
        -- index i, using 'letwith'.
        scan_outarrs <-
          certifying (foldMap resCerts scan_res) $
            letwith (map paramName scanout_params) (Var i) $
              map resSubExp scan_res'

        -- Store this iteration's map results at index i, the same way.
        map_outarrs <-
          certifying (foldMap resCerts map_res) $
            letwith (map paramName mapout_params) (Var i) $
              map resSubExp map_res

        pure . mkBody mempty . concat $
          [ scan_res',
            varsRes scan_outarrs,
            red_res',
            varsRes map_outarrs
          ]

  -- The loop also returns the final scan accumulators, which are not
  -- part of the original pattern; bind them to fresh, discarded names.
  names <-
    (++ patNames pat)
      <$> replicateM (length scanacc_params) (newVName "discard")
  letBindNames names $ DoLoop merge loopform loop_body
transformSOAC pat (Stream w arrs nes lam) = do
  -- Lower the stream to a loop of @w@ iterations.  Each iteration
  -- applies the lambda body to a chunk of size 1 taken at offset i.
  let (chunk_size_param, fold_params, chunk_params) =
        partitionChunkedFoldParameters (length nes) $ lambdaParams lam

  -- For each map-like result (those after the fold results), a unique
  -- loop parameter initialised with a scratch array of outer size w.
  mapout_merge <- forM (drop (length nes) $ lambdaReturnType lam) $ \t ->
    let t' = t `setOuterSize` w
        scratch = BasicOp $ Scratch (elemType t') (arrayDims t')
     in (,)
          <$> newParam "stream_mapout" (toDecl t' Unique)
          <*> letSubExp "stream_mapout_scratch" scratch

  -- Copy array-valued sub-expressions.  This is used for the neutral
  -- elements here and for the fold results in the loop body.  The copy
  -- is presumably needed because they become 'Unique' loop parameters.
  -- Left unannotated because it is used in two different monads.
  let copyIfArray se = do
        se_t <- subExpType se
        case (se_t, se) of
          (Array {}, Var v) ->
            letSubExp (baseString v) $ BasicOp $ Copy v
          _ -> pure se
  nes' <- mapM copyIfArray nes

  let onType t = t `toDecl` Unique
      merge = zip (map (fmap onType) fold_params) nes' ++ mapout_merge
      merge_params = map fst merge
      mapout_params = map fst mapout_merge

  i <- newVName "i"

  let loop_form = ForLoop i Int64 w []

  -- The chunk size is fixed at 1.
  letBindNames [paramName chunk_size_param] . BasicOp . SubExp $
    intConst Int64 1

  loop_body <- runBodyBuilder $
    localScope (scopeOf loop_form <> scopeOfFParams merge_params) $ do
      let slice =
            [DimSlice (Var i) (Var (paramName chunk_size_param)) (intConst Int64 1)]

      -- Bind each chunk parameter to the current chunk of its input.
      forM_ (zip chunk_params arrs) $ \(p, arr) ->
        letBindNames [paramName p] . BasicOp . Index arr $
          fullSlice (paramType p) slice

      (res, mapout_res) <- splitAt (length nes) <$> bodyBind (lambdaBody lam)

      res' <- mapM (copyIfArray . resSubExp) res

      -- Write each map-like result into its output array at the chunk's
      -- position, keeping the result's certificates.
      mapout_res' <- forM (zip mapout_params mapout_res) $ \(p, SubExpRes cs se) ->
        certifying cs . letSubExp "mapout_res" . BasicOp $
          Update Unsafe (paramName p) (fullSlice (paramType p) slice) se

      mkBodyM mempty $ subExpsRes $ res' ++ mapout_res'

  letBind pat $ DoLoop merge loop_form loop_body
-- Scatter: lowered to a sequential for-loop over @len@ iterations.  Each
-- iteration indexes the input arrays @ivs@, applies @lam@, and writes the
-- resulting (index, value) pairs into the destination arrays with 'Update'.
transformSOAC pat (Scatter len ivs lam as) = do
  iter <- newVName "write_iter"

  let (as_ws, as_ns, as_vs) = unzip3 as
  ts <- mapM lookupType as_vs
  asOuts <- mapM (newIdent "write_out") ts

  -- The destination arrays are threaded through the loop as (unique) merge
  -- parameters, initialised with the scatter's destination arrays.
  let merge = loopMerge asOuts $ map Var as_vs
      iter_scope = M.insert iter (IndexName Int64) $ scopeOfFParams $ map fst merge

  loopBody <- runBodyBuilder . localScope iter_scope $ do
    -- Read element @iter@ of every input array.
    ivs' <- forM ivs $ \iv -> do
      iv_t <- lookupType iv
      letSubExp "write_iv" $ BasicOp $ Index iv $ fullSlice iv_t [DimFix $ Var iter]
    ivs'' <- bindLambda lam (map (BasicOp . SubExp) ivs')

    -- Group the lambda results by the destination array they target.
    let indexes = groupScatterResults (zip3 as_ws as_ns $ map identName asOuts) ivs''

    -- Fold every (indices, value) pair for one destination into successive
    -- updates of that array, carrying the certificates of both the indices
    -- and the value.
    ress <- forM indexes $ \(_, arr, indexes') -> do
      arr_t <- lookupType arr
      let saveInArray arr' (indexCur, SubExpRes value_cs valueCur) =
            certifying (foldMap resCerts indexCur <> value_cs) . letExp "write_out" $
              BasicOp $
                -- NOTE(review): presumably 'Safe' makes out-of-bounds writes
                -- no-ops, matching scatter semantics; confirm in the IR docs.
                Update Safe arr' (fullSlice arr_t $ map (DimFix . resSubExp) indexCur) valueCur
      foldM saveInArray arr indexes'

    pure $ resultBody $ map Var ress

  letBind pat $ DoLoop merge (ForLoop iter Int64 len []) loopBody
-- Hist: lowered to a sequential for-loop over @len@ iterations.  Each
-- iteration applies the bucket function to element @iter@ of the images and,
-- for every histogram operator, combines the produced values into the bucket
-- selected by the produced indices (skipping out-of-bounds indices).
transformSOAC pat (Hist len imgs ops bucket_fun) = do
  iter <- newVName "iter"

  -- The histogram destination arrays become the loop's merge variables.
  hists_ts <- mapM lookupType $ concatMap histDest ops
  hists_out <- mapM (newIdent "dests") hists_ts
  let merge = loopMerge hists_out $ concatMap (map Var . histDest) ops
      iter_scope = M.insert iter (IndexName Int64) $ scopeOfFParams $ map fst merge

  loopBody <- runBodyBuilder . localScope iter_scope $ do
    -- Bind element @iter@ of every image and apply the bucket function.
    imgs' <- forM imgs $ \img -> do
      img_t <- lookupType img
      letSubExp "pixel" $ BasicOp $ Index img $ fullSlice img_t [DimFix $ Var iter]
    imgs'' <- map resSubExp <$> bindLambda bucket_fun (map (BasicOp . SubExp) imgs')

    -- The bucket function results hold all operators' indices first (one per
    -- dimension of each histogram shape), followed by all operators' values
    -- (one per result of each operator lambda).
    let idx_ranks = map (shapeRank . histShape) ops
        val_counts = map (length . lambdaReturnType . histOp) ops
        (idx_res, val_res) = splitAt (sum idx_ranks) imgs''
        ops_inds = chunks idx_ranks idx_res
        vals = chunks val_counts val_res
        hists_out' = chunks val_counts $ map identName hists_out

    hists_out'' <- forM (zip4 hists_out' ops ops_inds vals) $ \(hist, op, idxs, val) -> do
      -- When the indices are out of bounds, the histograms are returned
      -- unchanged.  An operator with no destination arrays is treated as
      -- always out of bounds.
      let outside_bounds_branch = buildBody_ $ pure $ varsRes hist
          oob = case hist of
            [] -> eSubExp $ constant True
            arr : _ -> eOutOfBounds arr $ map eSubExp idxs

      letTupExp "new_histo" <=< eIf oob outside_bounds_branch $
        buildBody_ $ do
          -- Read the current bucket contents.
          h_val <- forM hist $ \arr -> do
            arr_t <- lookupType arr
            letSubExp "read_hist" $ BasicOp $ Index arr $ fullSlice arr_t $ map DimFix idxs

          -- Combine them with the new values using the operator.
          h_val' <- bindLambda (histOp op) $ map (BasicOp . SubExp) $ h_val ++ val

          -- Write the combined values back in place.
          hist' <- forM (zip hist h_val') $ \(arr, SubExpRes cs v) -> do
            arr_t <- lookupType arr
            certifying cs . letInPlace "hist_out" arr (fullSlice arr_t $ map DimFix idxs) $
              BasicOp $ SubExp v

          pure $ varsRes hist'

    pure $ resultBody $ map Var $ concat hists_out''

  letBind pat $ DoLoop merge (ForLoop iter Int64 len []) loopBody
-- | First-order transform the body of a lambda, keeping its parameters and
-- return types unchanged.  The body is transformed with the lambda's
-- parameters in scope.
transformLambda ::
  ( MonadFreshNames m,
    Buildable rep,
    BuilderOps rep,
    LocalScope somerep m,
    SameScope somerep rep,
    LetDec rep ~ LetDec SOACS,
    CanBeAliased (Op rep)
  ) =>
  Lambda SOACS ->
  m (AST.Lambda rep)
transformLambda (Lambda params body rettype) = do
  body' <-
    runBodyBuilder . localScope (scopeOfLParams params) $
      transformBody body
  pure $ Lambda params body' rettype
-- | For each destination @k@ paired with a value @v@, bind a new name for
-- @k@ with @v@ written in place at outer index @i@, and return the new names.
-- For accumulator-typed ('Acc') destinations, no in-place update is
-- performed; the value itself is bound under the new name.  Extra
-- destinations or values beyond the shorter list are ignored.
letwith :: Transformer m => [VName] -> SubExp -> [SubExp] -> m [VName]
letwith ks i vs = zipWithM update ks vs
  where
    update k v = do
      k_t <- lookupType k
      case k_t of
        Acc {} -> letExp "lw_acc" $ BasicOp $ SubExp v
        _ -> letInPlace "lw_dest" k (fullSlice k_t [DimFix i]) $ BasicOp $ SubExp v
-- | Inline a lambda: bind each argument expression to the corresponding
-- parameter name, then bind the lambda body and return its result.
--
-- Arguments for parameters of primitive type are bound directly.  All others
-- are bound through 'eCopy' (presumably so the body may consume the parameter
-- without affecting the argument; confirm).  Parameters and arguments are
-- paired with 'zip', so a length mismatch is silently truncated.
bindLambda ::
  Transformer m =>
  AST.Lambda (Rep m) ->
  [AST.Exp (Rep m)] ->
  m Result
bindLambda (Lambda params body _) args = do
  forM_ (zip params args) $ \(param, arg) -> do
    let bound
          | primType (paramType param) = pure arg
          | otherwise = eCopy (pure arg)
    letBindNames [paramName param] =<< bound
  bodyBind body
-- | Build loop merge parameters from identifiers and their initial values,
-- marking every parameter as 'Unique'.
loopMerge :: [Ident] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge vars = loopMerge' [(var, Unique) | var <- vars]
-- | Build loop merge parameters by pairing each identifier (with the given
-- uniqueness) with an initial value.  Each parameter has no attributes, and
-- its declared type is the identifier's type with that uniqueness.  The result
-- is truncated to the shorter of the two lists.
loopMerge' :: [(Ident, Uniqueness)] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge' = zipWith mkMerge
  where
    mkMerge (Ident pname ptype, u) val = (Param mempty pname (toDecl ptype u), val)