{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.IR.SegOp
( SegOp (..),
segLevel,
segBody,
segSpace,
typeCheckSegOp,
SegSpace (..),
scopeOfSegSpace,
segSpaceDims,
HistOp (..),
histType,
splitHistResults,
SegBinOp (..),
segBinOpResults,
segBinOpChunks,
KernelBody (..),
aliasAnalyseKernelBody,
consumedInKernelBody,
ResultManifest (..),
KernelResult (..),
kernelResultCerts,
kernelResultSubExp,
SegOpMapper (..),
identitySegOpMapper,
mapSegOpM,
traverseSegOpStms,
simplifySegOp,
HasSegOp (..),
segOpRules,
segOpReturns,
)
where
import Control.Category
import Control.Monad.Identity hiding (mapM_)
import Control.Monad.Reader hiding (mapM_)
import Control.Monad.State.Strict
import Control.Monad.Writer hiding (mapM_)
import Data.Bifunctor (first)
import Data.Bitraversable
import Data.Foldable (traverse_)
import Data.List
( elemIndex,
foldl',
groupBy,
intersperse,
isPrefixOf,
partition,
)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR
import Futhark.IR.Aliases
( Aliases,
CanBeAliased (..),
)
import Futhark.IR.Mem
import Futhark.IR.Prop.Aliases
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty
( Doc,
apply,
hsep,
parens,
ppTuple',
pretty,
(<+>),
(</>),
)
import Futhark.Util.Pretty qualified as PP
import Prelude hiding (id, (.))
-- | The operator of a histogram-style computation, parameterised over
-- the representation @rep@ of its combining 'Lambda'.
data HistOp rep = HistOp
  { -- | Shape of the histogram.  Its rank is the number of index
    -- results that 'splitHistResults' attributes to this operator,
    -- and it forms the outer part of the array shape computed by
    -- 'histType'.
    histShape :: Shape,
    -- | NOTE(review): not consulted in the visible part of this
    -- module; presumably a hint about update contention -- confirm.
    histRaceFactor :: SubExp,
    -- | Destination arrays.  'splitHistResults' attributes one value
    -- result to this operator per destination.
    histDest :: [VName],
    -- | Presumably the neutral elements of 'histOp' -- confirm
    -- against the users of this field.
    histNeutral :: [SubExp],
    -- | Extra shape that 'histType' appends after 'histShape' when
    -- forming the type of each histogram array.
    histOpShape :: Shape,
    -- | The combining function.  'histType' takes the element types
    -- of the histogram arrays from its return type.
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)
-- | The array types of a 'HistOp': each return type of the operator
-- lambda, turned into an array whose shape is 'histShape' followed by
-- 'histOpShape'.
histType :: HistOp rep -> [Type]
histType hist =
  [elem_t `arrayOfShape` full_shape | elem_t <- lambdaReturnType (histOp hist)]
  where
    full_shape = histShape hist <> histOpShape hist
-- | Split a flat result list into per-operator @(indexes, values)@
-- pairs.
--
-- The leading @sum ranks@ results are treated as indexes, where each
-- operator contributes as many as the rank of its 'histShape'; the
-- remaining results are values, one per 'histDest' of each operator.
-- Both halves are partitioned with 'chunks'.
splitHistResults :: [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ops res = zip idx_chunks val_chunks
  where
    idx_counts = map (shapeRank . histShape) ops
    val_counts = map (length . histDest) ops
    (all_idxs, all_vals) = splitAt (sum idx_counts) res
    idx_chunks = chunks idx_counts all_idxs
    val_chunks = chunks val_counts all_vals
-- | A binary operator together with the metadata carried alongside
-- it, parameterised over the representation @rep@ of its 'Lambda'.
data SegBinOp rep = SegBinOp
  { -- | Commutativity annotation of the operator.
    segBinOpComm :: Commutativity,
    -- | The operator itself.
    segBinOpLambda :: Lambda rep,
    -- | Neutral elements.  'segBinOpResults' and 'segBinOpChunks' use
    -- the length of this list as the number of results belonging to
    -- the operator.
    segBinOpNeutral :: [SubExp],
    -- | NOTE(review): not consulted in the visible part of this
    -- module; presumably the shape of a vectorised operator --
    -- confirm.
    segBinOpShape :: Shape
  }
  deriving (Eq, Ord, Show)
-- | Total number of results across the given operators, counting one
-- per neutral element of each operator.
segBinOpResults :: [SegBinOp rep] -> Int
segBinOpResults ops = sum [length (segBinOpNeutral op) | op <- ops]
-- | Partition a list with 'chunks', using the number of neutral
-- elements of each operator as the successive chunk sizes.
segBinOpChunks :: [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks ops = chunks (map neutral_count ops)
  where
    neutral_count op = length (segBinOpNeutral op)
-- | The body of a kernel: a decoration, a sequence of statements, and
-- the list of results produced after those statements.
data KernelBody rep = KernelBody
  { -- | Representation-specific decoration of the body.
    kernelBodyDec :: BodyDec rep,
    -- | The statements of the body.
    kernelBodyStms :: Stms rep,
    -- | The results of the body.
    kernelBodyResult :: [KernelResult]
  }

-- The instances are derived standalone so that the 'RepTypes'
-- context can be stated explicitly.
deriving instance RepTypes rep => Eq (KernelBody rep)

deriving instance RepTypes rep => Ord (KernelBody rep)

deriving instance RepTypes rep => Show (KernelBody rep)
-- | Annotation on a 'Returns' result.  The constructor order is
-- significant because 'Ord' is derived.
--
-- NOTE(review): the exact meaning of each variant is not established
-- by the visible code beyond how it is pretty-printed -- confirm
-- against the simplifier and code generator.
data ResultManifest
  = -- | Pretty-printed as @returns (manifest)@.
    ResultNoSimplify
  | -- | Pretty-printed as plain @returns@.
    ResultMaySimplify
  | -- | Pretty-printed as @returns (private)@.
    ResultPrivate
  deriving (Eq, Show, Ord)
-- | A single result of a 'KernelBody'.  Every variant carries the
-- certificates ('Certs') attached to that result.
data KernelResult
  = -- | Return the given 'SubExp', annotated with a 'ResultManifest'.
    Returns ResultManifest Certs SubExp
  | -- | Write values into a destination array, which is consumed by
    -- the kernel body.  The fields are: certificates, a shape, the
    -- destination array, and a list of @(slice, value)@ writes.  The
    -- type checker requires the destination to have the result
    -- element type with exactly this shape.
    WriteReturns
      Certs
      Shape
      VName
      [(Slice SubExp, SubExp)]
  | -- | Return a tile held in the given array.  Each pair is
    -- @(dimension, tile size)@; the type checker requires the array's
    -- shape to be the list of tile sizes.
    TileReturns
      Certs
      [(SubExp, SubExp)]
      VName
  | -- | Return a register tile held in the given array.  Each triple
    -- is @(dimension, block tile, register tile)@; the type checker
    -- requires the array's shape to be all block tiles followed by
    -- all register tiles.
    RegTileReturns
      Certs
      [ ( SubExp,
          SubExp,
          SubExp
        )
      ]
      VName
  deriving (Eq, Show, Ord)
-- | The certificates attached to a 'KernelResult', whichever variant
-- it is.
kernelResultCerts :: KernelResult -> Certs
kernelResultCerts res = case res of
  Returns _ certs _ -> certs
  WriteReturns certs _ _ _ -> certs
  TileReturns certs _ _ -> certs
  RegTileReturns certs _ _ -> certs
-- | The 'SubExp' standing for a 'KernelResult': the returned
-- expression for 'Returns', and otherwise the array named by the
-- result, wrapped in 'Var'.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp res = case res of
  Returns _ _ se -> se
  WriteReturns _ _ dest _ -> Var dest
  TileReturns _ _ tile -> Var tile
  RegTileReturns _ _ tile -> Var tile
-- | The free variables of a 'KernelResult' are the combined free
-- variables of all of its fields except the 'ResultManifest'.
instance FreeIn KernelResult where
  freeIn' res = case res of
    Returns _ certs se ->
      freeIn' certs <> freeIn' se
    WriteReturns certs shape dest writes ->
      freeIn' certs <> freeIn' shape <> freeIn' dest <> freeIn' writes
    TileReturns certs dims arr ->
      freeIn' certs <> freeIn' dims <> freeIn' arr
    RegTileReturns certs dims_n_tiles arr ->
      freeIn' certs <> freeIn' dims_n_tiles <> freeIn' arr
-- | Free variables of a 'KernelBody': those of the decoration, the
-- statements and the results, with the names bound by the statements
-- passed to 'fvBind'.
instance ASTRep rep => FreeIn (KernelBody rep) where
  freeIn' (KernelBody dec stms res) =
    let bound_names = foldMap boundByStm stms
        all_free = freeIn' dec <> freeIn' stms <> freeIn' res
     in fvBind bound_names all_free
-- | Name substitution is applied to each of the three components of
-- the body.
instance ASTRep rep => Substitute (KernelBody rep) where
  substituteNames subst (KernelBody dec stms res) =
    KernelBody (sub dec) (sub stms) (sub res)
    where
      -- Explicit signature: the helper is used at three different
      -- types and would not be generalised otherwise.
      sub :: Substitute a => a -> a
      sub = substituteNames subst
-- | Name substitution is applied to every field of the result except
-- the 'ResultManifest', which is kept as is.
instance Substitute KernelResult where
  substituteNames subst res = case res of
    Returns manifest certs se ->
      Returns manifest (sub certs) (sub se)
    WriteReturns certs shape dest writes ->
      WriteReturns (sub certs) (sub shape) (sub dest) (sub writes)
    TileReturns certs dims arr ->
      TileReturns (sub certs) (sub dims) (sub arr)
    RegTileReturns certs dims_n_tiles arr ->
      RegTileReturns (sub certs) (sub dims_n_tiles) (sub arr)
    where
      -- Explicit signature: the helper is used at several different
      -- types and would not be generalised otherwise.
      sub :: Substitute a => a -> a
      sub = substituteNames subst
-- | Rename the decoration first, then the statements via
-- 'renamingStms'; the results are renamed inside its continuation.
instance ASTRep rep => Rename (KernelBody rep) where
  rename (KernelBody dec stms res) =
    rename dec >>= \dec' ->
      renamingStms stms (\stms' -> fmap (KernelBody dec' stms') (rename res))
-- | Renaming a result is delegated to 'substituteRename'.
instance Rename KernelResult where
  rename res = substituteRename res
-- | Run alias analysis over the statements of a 'KernelBody'.
--
-- The decoration and statements are wrapped in an ordinary 'Body'
-- with an empty result, analysed with 'Alias.analyseBody', and the
-- analysed decoration and statements are put back together with the
-- original, untouched kernel results.
aliasAnalyseKernelBody ::
  Alias.AliasableRep rep =>
  AliasTable ->
  KernelBody rep ->
  KernelBody (Aliases rep)
aliasAnalyseKernelBody aliases (KernelBody dec stms res) =
  KernelBody analysed_dec analysed_stms res
  where
    -- Lazy pattern binding; the analysed (empty) result is discarded.
    Body analysed_dec analysed_stms _ =
      Alias.analyseBody aliases (Body dec stms [])
-- | The names consumed in a 'KernelBody': whatever is consumed by its
-- statements (computed via 'consumedInBody' on a 'Body' with an empty
-- result), plus the destination array of every 'WriteReturns' result.
consumedInKernelBody ::
  Aliased rep =>
  KernelBody rep ->
  Names
consumedInKernelBody (KernelBody dec stms res) =
  consumed_by_stms <> foldMap consumedByResult res
  where
    consumed_by_stms = consumedInBody (Body dec stms [])

    -- Only 'WriteReturns' consumes anything.
    consumedByResult kres = case kres of
      WriteReturns _ _ dest _ -> oneName dest
      _ -> mempty
-- | Type-check a 'KernelBody' against the expected result types @ts@,
-- one per 'KernelResult'.
--
-- The body decoration is checked first.  The aliases of every
-- 'WriteReturns' destination are then registered as consumed, and the
-- statements are checked with 'TC.checkStms'.  Inside its
-- continuation the number of results is compared with the number of
-- expected types, and each result is checked against its type.
checkKernelBody ::
  TC.Checkable rep =>
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkKernelBody ts (KernelBody (_, dec) stms kres) = do
  TC.checkBodyDec dec
  -- Consumption is registered before the statements are checked.
  mapM_ consumeKernelResult kres
  TC.checkStms stms $ do
    unless (length ts == length kres) $
      TC.bad . TC.TypeError $
        "Kernel return type is "
          <> prettyTuple ts
          <> ", but body returns "
          <> prettyText (length kres)
          <> " values."
    zipWithM_ checkKernelResult kres ts
  where
    -- Require a sub-expression to be of type @i64@.  Written with an
    -- explicit argument so that the monomorphism restriction does not
    -- apply to it.
    requireI64 se = TC.require [Prim int64] se

    -- Only 'WriteReturns' consumes anything: the aliases of its
    -- destination array.
    consumeKernelResult (WriteReturns _ _ arr _) =
      TC.consume =<< TC.lookupAliases arr
    consumeKernelResult _ =
      pure ()

    -- Check one result against its expected type @t@.
    checkKernelResult (Returns _ cs what) t = do
      TC.checkCerts cs
      TC.require [t] what
    checkKernelResult (WriteReturns cs shape arr res) t = do
      TC.checkCerts cs
      mapM_ requireI64 $ shapeDims shape
      arr_t <- lookupType arr
      forM_ res $ \(slice, e) -> do
        traverse_ requireI64 slice
        TC.require [t] e
        -- This check is deliberately per write: the message mentions
        -- the written value, and no error is raised when there are no
        -- writes at all.
        unless (arr_t == t `arrayOfShape` shape) $
          TC.bad $
            TC.TypeError $
              "WriteReturns returning "
                <> prettyText e
                <> " of type "
                <> prettyText t
                <> ", shape="
                <> prettyText shape
                <> ", but destination array has type "
                <> prettyText arr_t
    checkKernelResult (TileReturns cs dims v) t = do
      TC.checkCerts cs
      forM_ dims $ \(dim, tile) -> do
        requireI64 dim
        requireI64 tile
      vt <- lookupType v
      -- The tile array must have the tile sizes as its shape.
      unless (vt == t `arrayOfShape` Shape (map snd dims)) $
        TC.bad $
          TC.TypeError $
            "Invalid type for TileReturns " <> prettyText v
    checkKernelResult (RegTileReturns cs dims_n_tiles arr) t = do
      TC.checkCerts cs
      mapM_ requireI64 dims
      mapM_ requireI64 blk_tiles
      mapM_ requireI64 reg_tiles
      arr_t <- lookupType arr
      unless (arr_t == expected) $
        TC.bad . TC.TypeError $
          -- Names the variant actually being checked; this message
          -- previously said "TileReturns".
          "Invalid type for RegTileReturns. Expected:\n "
            <> prettyText expected
            <> ",\ngot:\n "
            <> prettyText arr_t
      where
        (dims, blk_tiles, reg_tiles) = unzip3 dims_n_tiles
        -- Element type @t@ with all block tiles followed by all
        -- register tiles as the shape.
        expected = t `arrayOfShape` Shape (blk_tiles <> reg_tiles)
-- | Run 'stmMetrics' on every statement of a 'KernelBody'.
kernelBodyMetrics :: OpMetrics (Op rep) => KernelBody rep -> MetricsM ()
kernelBodyMetrics body = mapM_ stmMetrics (kernelBodyStms body)
-- | A 'KernelBody' is rendered as its stacked statements, followed by
-- @return@ and the comma-separated results in braces.  The
-- decoration is not printed.
instance PrettyRep rep => Pretty (KernelBody rep) where
  pretty (KernelBody _ stms res) =
    stms_doc </> "return" <+> res_doc
    where
      stms_doc = PP.stack (map pretty (stmsToList stms))
      res_doc = PP.braces (PP.commasep (map pretty res))
-- | The pretty-printed certificates as a singleton list, or the empty
-- list when there are no certificates.
certAnnots :: Certs -> [Doc ann]
certAnnots cs =
  if cs == mempty
    then []
    else [pretty cs]
-- | Every kernel result is printed as its certificate annotations (if
-- any, see 'certAnnots') followed by a constructor-specific document,
-- all joined with 'hsep'.
instance Pretty KernelResult where
  pretty kres =
    case kres of
      Returns manifest cs what ->
        withCerts cs $ returnsKeyword manifest <+> pretty what
      WriteReturns cs shape arr res ->
        withCerts cs $
          pretty arr
            <+> PP.colon
            <+> pretty shape
            </> "with"
            <+> PP.apply [pretty slice <+> "=" <+> pretty e | (slice, e) <- res]
      TileReturns cs dims v ->
        withCerts cs $
          "tile"
            <> apply [pretty dim <+> "/" <+> pretty tile | (dim, tile) <- dims]
            <+> pretty v
      RegTileReturns cs dims_n_tiles v ->
        withCerts cs $
          "blkreg_tile"
            <> apply (map regTileDim dims_n_tiles)
            <+> pretty v
    where
      -- Prefix a document with the certificate annotations.
      withCerts :: Certs -> Doc a -> Doc a
      withCerts cs doc = hsep (certAnnots cs <> [doc])

      -- The keyword used for a 'Returns', which depends on its manifest.
      returnsKeyword :: ResultManifest -> Doc a
      returnsKeyword ResultNoSimplify = "returns (manifest)"
      returnsKeyword ResultPrivate = "returns (private)"
      returnsKeyword ResultMaySimplify = "returns"

      -- One dimension of a 'RegTileReturns': @dim / (blk_tile * reg_tile)@.
      regTileDim (dim, blk_tile, reg_tile) =
        pretty dim
          <+> "/"
          <+> parens (pretty blk_tile <+> "*" <+> pretty reg_tile)
-- | The index space of a segmented operation: one distinguished name
-- plus a list of names, each paired with a 'SubExp'.  'scopeOfSegSpace'
-- brings all of these names into scope as 64-bit indexes.
data SegSpace = SegSpace
  { -- | A single name that 'scopeOfSegSpace' binds alongside the
    -- per-dimension names.  Presumably the flattened index, going by
    -- the field name; confirm against the code that uses it.
    segFlat :: VName,
    -- | The per-dimension names, each with an associated 'SubExp'.
    -- 'segSpaceDims' projects out the 'SubExp's, and 'checkSegSpace'
    -- requires each of them to have type @int64@.
    unSegSpace :: [(VName, SubExp)]
  }
  deriving (Eq, Ord, Show)
-- | The 'SubExp' component of every entry in the space, in order.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims = map snd . unSegSpace
-- | The scope contributed by a 'SegSpace': the 'segFlat' name and the
-- name of every entry in 'unSegSpace', each bound as an 'IndexName' of
-- type 'Int64'.
scopeOfSegSpace :: SegSpace -> Scope rep
scopeOfSegSpace (SegSpace phys space) =
  M.fromList [(v, IndexName Int64) | v <- phys : map fst space]
-- | Type-check a 'SegSpace' by requiring every 'SubExp' in it (see
-- 'segSpaceDims') to have type @int64@.  The names are not checked.
checkSegSpace :: TC.Checkable rep => SegSpace -> TC.TypeM rep ()
checkSegSpace space =
  mapM_ (TC.require [Prim int64]) (segSpaceDims space)
-- | A segmented operation, parameterised over a level type @lvl@ and a
-- representation @rep@.  Every constructor carries a level, a
-- 'SegSpace', a list of types, and a 'KernelBody'; these can be
-- retrieved with 'segLevel', 'segSpace' and 'segBody'.  The types of
-- the values produced are computed by 'segOpType'.
data SegOp lvl rep
  = -- | No operators besides the body.
    SegMap lvl SegSpace [Type] (KernelBody rep)
  | -- | Carries a list of 'SegBinOp's.  'segOpType' drops the last
    -- dimension of the space when computing the operator result types.
    SegRed lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | -- | Carries a list of 'SegBinOp's.  'segOpType' keeps every
    -- dimension of the space in the operator result types.
    SegScan lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | -- | Carries a list of 'HistOp's, whose destinations are reported
    -- as consumed by the 'AliasedOp' instance.
    SegHist lvl SegSpace [HistOp rep] [Type] (KernelBody rep)
  deriving (Eq, Ord, Show)
-- | The level (first field) of a 'SegOp'.
segLevel :: SegOp lvl rep -> lvl
segLevel segop =
  case segop of
    SegMap lvl _ _ _ -> lvl
    SegRed lvl _ _ _ _ -> lvl
    SegScan lvl _ _ _ _ -> lvl
    SegHist lvl _ _ _ _ -> lvl
-- | The 'SegSpace' (second field) of a 'SegOp'.
segSpace :: SegOp lvl rep -> SegSpace
segSpace segop =
  case segop of
    SegMap _ space _ _ -> space
    SegRed _ space _ _ _ -> space
    SegScan _ space _ _ _ -> space
    SegHist _ space _ _ _ -> space
-- | The 'KernelBody' (last field) of a 'SegOp'.
segBody :: SegOp lvl rep -> KernelBody rep
segBody (SegMap _ _ _ kbody) = kbody
segBody (SegRed _ _ _ _ kbody) = kbody
segBody (SegScan _ _ _ _ kbody) = kbody
segBody (SegHist _ _ _ _ kbody) = kbody
-- | Given the space of a 'SegOp', the per-result type @t@, and the
-- corresponding 'KernelResult', compute the type of the value that
-- result produces.  Only a plain 'Returns' consults the space; every
-- other result carries the shape that is put around @t@.
segResultShape :: SegSpace -> Type -> KernelResult -> Type
segResultShape space t kres =
  case kres of
    -- Wrap @t@ in one array row per space dimension, innermost last.
    Returns {} ->
      foldr (\dim elem_t -> arrayOfRow elem_t dim) t (segSpaceDims space)
    WriteReturns _ shape _ _ ->
      t `arrayOfShape` shape
    -- Only the first component of each dimension pair sizes the result.
    TileReturns _ dims _ ->
      t `arrayOfShape` Shape [dim | (dim, _) <- dims]
    -- Likewise, only the first component of each triple is used.
    RegTileReturns _ dims_n_tiles _ ->
      t `arrayOfShape` Shape [dim | (dim, _, _) <- dims_n_tiles]
-- | The types of the values produced by a 'SegOp'.
--
-- * 'SegMap': each declared type is shaped by its kernel result via
--   'segResultShape'.
--
-- * 'SegRed' and 'SegScan': first the operator results, then the
--   remaining (map-like) results shaped via 'segResultShape'.  For
--   'SegRed' the operator results are sized by all but the last
--   dimension of the space; for 'SegScan' by all of them.
--
-- * 'SegHist': one group of results per 'HistOp', sized by all but the
--   last dimension of the space followed by the 'histShape' and
--   'histOpShape' of the operator.
--
-- The 'SegRed' and 'SegHist' cases use 'init' on the dimensions of the
-- space, so they fail on an empty space if the resulting shape is
-- ever demanded.
segOpType :: SegOp lvl rep -> [Type]
segOpType segop =
  case segop of
    SegMap _ space ts kbody ->
      zipWith (segResultShape space) ts (kernelBodyResult kbody)
    SegRed _ space reds ts kbody ->
      binOpResults space (init (segSpaceDims space)) reds ts kbody
    SegScan _ space scans ts kbody ->
      binOpResults space (segSpaceDims space) scans ts kbody
    SegHist _ space ops _ _ ->
      [ t `arrayOfShape` shape
        | op <- ops,
          let shape =
                Shape (init (segSpaceDims space))
                  <> histShape op
                  <> histOpShape op,
          t <- lambdaReturnType (histOp op)
      ]
  where
    -- Shared by SegRed and SegScan: the operator results, sized by
    -- @outer_dims@ followed by the operator's own shape, and then the
    -- remaining declared types shaped by the remaining kernel results.
    binOpResults ::
      SegSpace -> [SubExp] -> [SegBinOp r] -> [Type] -> KernelBody r -> [Type]
    binOpResults space outer_dims binops ts kbody =
      op_ts ++ zipWith (segResultShape space) map_ts map_res
      where
        op_ts =
          [ t `arrayOfShape` (Shape outer_dims <> segBinOpShape op)
            | op <- binops,
              t <- lambdaReturnType (segBinOpLambda op)
          ]
        num_op_ts = length op_ts
        map_ts = drop num_op_ts ts
        map_res = drop num_op_ts (kernelBodyResult kbody)
-- | The type of a 'SegOp' is 'segOpType' converted with
-- 'staticShapes'; the scope is not consulted.
instance TypedOp (SegOp lvl rep) where
  opType segop = pure (staticShapes (segOpType segop))
-- | Alias information for 'SegOp's.
instance (ASTConstraints lvl, Aliased rep) => AliasedOp (SegOp lvl rep) where
  -- One empty alias set per result type.
  opAliases segop = mempty <$ segOpType segop

  -- Whatever the kernel body consumes is consumed by the operation; a
  -- 'SegHist' additionally consumes the destinations of its operators.
  consumedInOp segop =
    case segop of
      SegHist _ _ ops _ kbody ->
        namesFromList (concatMap histDest ops) <> consumedInKernelBody kbody
      _ ->
        consumedInKernelBody (segBody segop)
-- | Type-check a 'SegOp'.  The first argument checks the level; it is
-- run before anything else.
--
-- 'SegMap', 'SegRed' and 'SegScan' are delegated to 'checkScanRed'
-- (with no operators for 'SegMap').  For 'SegHist' the space and the
-- declared types are checked, and then, with the names of the space in
-- scope, every 'HistOp' and the kernel body are checked, and finally
-- the declared types are compared with what the operators demand.
typeCheckSegOp ::
  TC.Checkable rep =>
  (lvl -> TC.TypeM rep ()) ->
  SegOp lvl (Aliases rep) ->
  TC.TypeM rep ()
typeCheckSegOp checkLvl segop = do
  checkLvl (segLevel segop)
  case segop of
    SegMap _ space ts kbody ->
      checkScanRed space [] ts kbody
    SegRed _ space reds ts kbody ->
      checkScanRed space (map binOpParts reds) ts kbody
    SegScan _ space scans ts kbody ->
      checkScanRed space (map binOpParts scans) ts kbody
    SegHist _ space ops ts kbody -> do
      checkSegSpace space
      mapM_ TC.checkType ts

      TC.binding (scopeOfSegSpace space) $ do
        nes_ts <- mapM (checkHistOp (init (segSpaceDims space))) ops

        checkKernelBody ts kbody

        -- The body must return one int64 per dimension of each
        -- operator's 'histShape', followed by the values for all the
        -- operators.
        let bucket_ts =
              concatMap
                (\op -> replicate (shapeRank (histShape op)) (Prim int64))
                ops
            expected_ts = bucket_ts ++ concat nes_ts
        unless (expected_ts == ts) $
          TC.bad $
            TC.TypeError $
              "SegHist body has return type "
                <> prettyTuple ts
                <> " but should have type "
                <> prettyTuple expected_ts
  where
    -- The pieces of a 'SegBinOp' that 'checkScanRed' needs.
    binOpParts :: SegBinOp r -> (Lambda r, [SubExp], Shape)
    binOpParts op = (segBinOpLambda op, segBinOpNeutral op, segBinOpShape op)

    -- Check one 'HistOp', given all but the last dimension of the
    -- space, and return the value types the body must produce for it.
    checkHistOp ::
      TC.Checkable r =>
      [SubExp] ->
      HistOp (Aliases r) ->
      TC.TypeM r [Type]
    checkHistOp segment_dims (HistOp dest_shape rf dests nes vec_shape lam) = do
      mapM_ (TC.require [Prim int64]) dest_shape
      TC.require [Prim int64] rf
      ne_args <- mapM TC.checkArg nes
      mapM_ (TC.require [Prim int64]) (shapeDims vec_shape)

      -- The operator is checked against two copies of the neutral
      -- elements, with as many outer dimensions stripped as the
      -- operator shape has, and without aliases.
      let strip = stripArray (shapeRank vec_shape)
          lam_args = [TC.noArgAliases (first strip arg) | arg <- ne_args ++ ne_args]
      TC.checkLambda lam lam_args

      -- The operator must return exactly the neutral element types.
      let ne_ts = map TC.argType ne_args
          ret_ts = lambdaReturnType lam
      unless (ne_ts == ret_ts) $
        TC.bad $
          TC.TypeError $
            "SegHist operator has return type "
              <> prettyTuple ret_ts
              <> " but neutral element has type "
              <> prettyTuple ne_ts

      -- Each destination must be an array of the matching element type
      -- with the full destination shape, and it is consumed.
      let full_dest_shape = Shape segment_dims <> dest_shape <> vec_shape
      forM_ (zip ne_ts dests) $ \(t, dest) -> do
        TC.requireI [t `arrayOfShape` full_dest_shape] dest
        TC.lookupAliases dest >>= TC.consume

      pure [t `arrayOfShape` vec_shape | t <- ne_ts]
-- | Shared type checking for 'SegMap', 'SegRed' and 'SegScan'.  Each
-- operator is given as its lambda, its neutral elements, and its
-- shape.
--
-- The space and the declared types are checked first.  Then, with the
-- names of the space in scope, every operator is checked, the leading
-- declared types are compared with the types the operators demand,
-- and finally the kernel body is checked against all declared types.
checkScanRed ::
  TC.Checkable rep =>
  SegSpace ->
  [(Lambda (Aliases rep), [SubExp], Shape)] ->
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkScanRed space ops ts kbody = do
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    expecting <- concat <$> mapM checkOp ops

    -- Only as many declared types as the operators demand are compared.
    let got = take (length expecting) ts
    unless (expecting == got) $
      TC.bad $
        TC.TypeError $
          "Wrong return for body (does not match neutral elements; expected "
            <> prettyText expecting
            <> "; found "
            <> prettyText got
            <> ")"

    checkKernelBody ts kbody
  where
    -- Check a single operator and return the types the body must
    -- produce for it: the neutral element types, each wrapped in the
    -- operator shape.
    checkOp ::
      TC.Checkable r =>
      (Lambda (Aliases r), [SubExp], Shape) ->
      TC.TypeM r [Type]
    checkOp (lam, nes, shape) = do
      mapM_ (TC.require [Prim int64]) (shapeDims shape)
      ne_args <- mapM TC.checkArg nes

      -- The operator is applied to two copies of the neutral elements,
      -- without aliases.
      TC.checkLambda lam (map TC.noArgAliases (ne_args ++ ne_args))

      let ne_ts = map TC.argType ne_args
      unless (lambdaReturnType lam == ne_ts) $
        TC.bad (TC.TypeError "wrong type for operator or neutral elements.")

      pure [t `arrayOfShape` shape | t <- ne_ts]
-- | A bundle of monadic transformations, one per kind of component that
-- the 'SegOp' traversals in this module visit.  @frep@ is the
-- representation of the input and @trep@ that of the output; @m@ is the
-- monad in which the traversal runs.
data SegOpMapper lvl frep trep m = SegOpMapper
  { -- | Applied to every 'SubExp' the traversal encounters.
    mapOnSegOpSubExp :: SubExp -> m SubExp,
    -- | Applied to every operator 'Lambda'; may change representation.
    mapOnSegOpLambda :: Lambda frep -> m (Lambda trep),
    -- | Applied to the 'KernelBody'; may change representation.
    mapOnSegOpBody :: KernelBody frep -> m (KernelBody trep),
    -- | Applied to every 'VName' the traversal encounters.
    mapOnSegOpVName :: VName -> m VName,
    -- | Applied to the level component of the 'SegOp'.
    mapOnSegOpLevel :: lvl -> m lvl
  }
-- | A 'SegOpMapper' in which every field is 'pure': each component is
-- returned unchanged and no effects are performed.  Intended as a base
-- for record updates that override only the fields of interest.
identitySegOpMapper :: Monad m => SegOpMapper lvl rep rep m
identitySegOpMapper =
  SegOpMapper
    { mapOnSegOpSubExp = pure,
      mapOnSegOpLambda = pure,
      mapOnSegOpBody = pure,
      mapOnSegOpVName = pure,
      mapOnSegOpLevel = pure
    }
-- | Apply a 'SegOpMapper' to a 'SegSpace'.  The first 'VName' of the
-- space goes through 'mapOnSegOpVName'; for each @(name, size)@ pair the
-- name goes through 'mapOnSegOpVName' and the size through
-- 'mapOnSegOpSubExp'.  Effects run left to right in that order.
mapOnSegSpace ::
  Monad f => SegOpMapper lvl frep trep f -> SegSpace -> f SegSpace
mapOnSegSpace tv (SegSpace phys dims) =
  SegSpace
    <$> mapOnSegOpVName tv phys
    <*> traverse onDim dims
  where
    -- Transform both halves of a single (index variable, size) pair.
    onDim = bitraverse (mapOnSegOpVName tv) (mapOnSegOpSubExp tv)
-- | Apply a 'SegOpMapper' to a 'SegBinOp'.  The commutativity flag is
-- kept as is; the operator goes through 'mapOnSegOpLambda', and both the
-- neutral elements and the dimensions of the shape go through
-- 'mapOnSegOpSubExp' (in that order).
mapSegBinOp ::
  Monad m =>
  SegOpMapper lvl frep trep m ->
  SegBinOp frep ->
  m (SegBinOp trep)
mapSegBinOp tv (SegBinOp comm red_op nes shape) =
  SegBinOp comm
    <$> mapOnSegOpLambda tv red_op
    <*> traverse (mapOnSegOpSubExp tv) nes
    -- The shape is traversed directly, as is done for shapes elsewhere in
    -- this module, rather than being unwrapped and rebuilt by hand.
    <*> traverse (mapOnSegOpSubExp tv) shape
-- | Apply a 'SegOpMapper' to every component of a 'SegOp', rebuilding
-- the same constructor from the transformed parts.  Components are
-- visited in constructor-field order: level, space, operators (if any),
-- result types, and finally the kernel body.
--
-- NOTE(review): the result types of 'SegMap' are transformed with
-- 'mapOnSegOpType', whereas the other constructors use 'mapOnType' with
-- only the 'SubExp' mapper.  That difference is preserved here; confirm
-- whether it is intentional.
mapSegOpM ::
  Monad m =>
  SegOpMapper lvl frep trep m ->
  SegOp lvl frep ->
  m (SegOp lvl trep)
mapSegOpM tv segop =
  case segop of
    SegMap lvl space ts body ->
      SegMap
        <$> onLevel lvl
        <*> onSpace space
        <*> mapM (mapOnSegOpType tv) ts
        <*> onBody body
    SegRed lvl space reds ts body ->
      SegRed
        <$> onLevel lvl
        <*> onSpace space
        <*> mapM (mapSegBinOp tv) reds
        <*> onTypes ts
        <*> onBody body
    SegScan lvl space scans ts body ->
      SegScan
        <$> onLevel lvl
        <*> onSpace space
        <*> mapM (mapSegBinOp tv) scans
        <*> onTypes ts
        <*> onBody body
    SegHist lvl space ops ts body ->
      SegHist
        <$> onLevel lvl
        <*> onSpace space
        <*> mapM onHistOp ops
        <*> onTypes ts
        <*> onBody body
  where
    onLevel = mapOnSegOpLevel tv
    onSpace = mapOnSegSpace tv
    onBody = mapOnSegOpBody tv
    onSubExp = mapOnSegOpSubExp tv

    -- Result types of every constructor except 'SegMap'.
    onTypes ts = mapM (mapOnType onSubExp) ts

    -- Transform each field of a histogram operator, in field order.
    onHistOp (HistOp w rf arrs nes shape op) =
      HistOp
        <$> traverse onSubExp w
        <*> onSubExp rf
        <*> mapM (mapOnSegOpVName tv) arrs
        <*> mapM onSubExp nes
        <*> traverse onSubExp shape
        <*> mapOnSegOpLambda tv op
-- | Apply a 'SegOpMapper' to a 'Type'.
--
-- * 'Prim' and 'Mem' types are returned unchanged without running any
--   mapper function.
--
-- * For 'Acc', the name goes through 'mapOnSegOpVName', the index space
--   through 'mapOnSegOpSubExp', and the shapes of the element types
--   through 'mapOnSegOpSubExp'.
--
-- * For 'Array', the shape goes through 'mapOnSegOpSubExp'.
mapOnSegOpType ::
  Monad m =>
  SegOpMapper lvl frep trep m ->
  Type ->
  m Type
mapOnSegOpType _tv t@Prim {} = pure t
mapOnSegOpType tv (Acc acc ispace ts u) =
  Acc
    <$> mapOnSegOpVName tv acc
    <*> traverse (mapOnSegOpSubExp tv) ispace
    -- 'bitraverse' visits the shape half of each element type and leaves
    -- the other half alone.
    <*> traverse (bitraverse (traverse (mapOnSegOpSubExp tv)) pure) ts
    <*> pure u
mapOnSegOpType tv (Array et shape u) =
  Array et <$> traverse (mapOnSegOpSubExp tv) shape <*> pure u
mapOnSegOpType _tv (Mem s) = pure $ Mem s
-- | Rephrase the operator lambda of a 'SegBinOp' with 'rephraseLambda'.
-- The commutativity, neutral elements and shape are carried over
-- unchanged.
rephraseBinOp ::
  Monad f =>
  Rephraser f from rep ->
  SegBinOp from ->
  f (SegBinOp rep)
rephraseBinOp r (SegBinOp comm lam nes shape) =
  rebuild <$> rephraseLambda r lam
  where
    -- Only the lambda changes representation; everything else is reused.
    rebuild lam' = SegBinOp comm lam' nes shape
-- | Rephrase a 'KernelBody': the body decoration goes through
-- 'rephraseBodyDec' and every statement through 'rephraseStm' (in that
-- order).  The kernel results are carried over unchanged.
rephraseKernelBody ::
  Monad f =>
  Rephraser f from rep ->
  KernelBody from ->
  f (KernelBody rep)
rephraseKernelBody r (KernelBody dec stms res) =
  KernelBody
    <$> rephraseBodyDec r dec
    <*> traverse (rephraseStm r) stms
    <*> pure res
-- | Rephrasing a 'SegOp' rephrases its operator lambdas (if any) and
-- then its kernel body.  The level, space and result types are carried
-- over unchanged.
instance RephraseOp (SegOp lvl) where
  rephraseInOp r (SegMap lvl space ts body) =
    SegMap lvl space ts <$> rephraseKernelBody r body
  rephraseInOp r (SegRed lvl space reds ts body) =
    SegRed lvl space
      <$> mapM (rephraseBinOp r) reds
      <*> pure ts
      <*> rephraseKernelBody r body
  rephraseInOp r (SegScan lvl space scans ts body) =
    SegScan lvl space
      <$> mapM (rephraseBinOp r) scans
      <*> pure ts
      <*> rephraseKernelBody r body
  rephraseInOp r (SegHist lvl space hists ts body) =
    SegHist lvl space
      <$> mapM onOp hists
      <*> pure ts
      <*> rephraseKernelBody r body
    where
      -- Only the operator lambda of a 'HistOp' depends on the
      -- representation; every other field is reused as is.
      onOp (HistOp w rf arrs nes shape op) =
        HistOp w rf arrs nes shape <$> rephraseLambda r op
-- | Apply a statement-traversal function to the statements inside a
-- 'SegOp'.  The statements of the kernel body are passed to @f@ together
-- with the scope of the 'SegOp's space.  Operator lambdas are handled by
-- 'traverseLambdaStms', with the space's scope prepended to whatever
-- scope it supplies.  All other components are left unchanged.
traverseSegOpStms :: Monad m => OpStmsTraverser m (SegOp lvl rep) rep
traverseSegOpStms f segop = mapSegOpM mapper segop
  where
    -- The names bound by the space of this 'SegOp'.
    seg_scope = scopeOfSegSpace (segSpace segop)

    -- @f@ with the space's scope added in front of the given scope.
    withSegScope scope = f (seg_scope <> scope)

    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = traverseLambdaStms withSegScope,
          mapOnSegOpBody = onBody
        }

    -- Only the statements are traversed; decoration and results are kept.
    onBody (KernelBody dec stms res) =
      KernelBody dec <$> f seg_scope stms <*> pure res
-- | Name substitution on a 'SegOp' is a pure traversal with 'mapSegOpM'
-- in the 'Identity' monad, applying 'substituteNames' to every kind of
-- component.
instance
  (ASTRep rep, Substitute lvl) =>
  Substitute (SegOp lvl rep)
  where
  substituteNames subst = runIdentity . mapSegOpM substitute
    where
      substitute =
        SegOpMapper
          { mapOnSegOpSubExp = pure . substituteNames subst,
            mapOnSegOpLambda = pure . substituteNames subst,
            mapOnSegOpBody = pure . substituteNames subst,
            mapOnSegOpVName = pure . substituteNames subst,
            mapOnSegOpLevel = pure . substituteNames subst
          }
-- | Renaming a 'SegOp' first marks the names in the scope of its space as
-- bound (via 'renameBound'), and then applies 'rename' to every
-- component through 'mapSegOpM'.
instance (ASTRep rep, ASTConstraints lvl) => Rename (SegOp lvl rep) where
  rename op =
    renameBound spaceNames $ mapSegOpM renamer op
    where
      -- The names brought into scope by the space of this 'SegOp'.
      spaceNames = M.keys (scopeOfSegSpace (segSpace op))

      -- One 'rename' per field of 'SegOpMapper', in field order.
      renamer = SegOpMapper rename rename rename rename rename
-- | The free variables of a 'SegOp' are gathered by traversing it with
-- 'mapSegOpM' in a state monad that accumulates the 'freeIn'' of every
-- component.  The names in the scope of the space are then treated as
-- bound via 'fvBind'.
instance (ASTRep rep, FreeIn lvl) => FreeIn (SegOp lvl rep) where
  freeIn' e =
    fvBind (namesFromList $ M.keys $ scopeOfSegSpace (segSpace e)) $
      flip execState mempty $
        mapSegOpM free e
    where
      -- Append @f x@ to the accumulated state and return @x@ unchanged,
      -- so the traversal rebuilds an identical 'SegOp'.
      walk f x = modify (<> f x) >> pure x

      free =
        SegOpMapper
          { mapOnSegOpSubExp = walk freeIn',
            mapOnSegOpLambda = walk freeIn',
            mapOnSegOpBody = walk freeIn',
            mapOnSegOpVName = walk freeIn',
            mapOnSegOpLevel = walk freeIn'
          }
-- | Metrics for a 'SegOp' are recorded 'inside' a label named after the
-- constructor.  Within it, the metrics of each operator lambda (if any)
-- are recorded first, followed by those of the kernel body.
instance OpMetrics (Op rep) => OpMetrics (SegOp lvl rep) where
  opMetrics (SegMap _ _ _ body) =
    inside "SegMap" $ kernelBodyMetrics body
  opMetrics (SegRed _ _ reds _ body) =
    inside "SegRed" $ do
      mapM_ (lambdaMetrics . segBinOpLambda) reds
      kernelBodyMetrics body
  opMetrics (SegScan _ _ scans _ body) =
    inside "SegScan" $ do
      mapM_ (lambdaMetrics . segBinOpLambda) scans
      kernelBodyMetrics body
  opMetrics (SegHist _ _ ops _ body) =
    inside "SegHist" $ do
      mapM_ (lambdaMetrics . histOp) ops
      kernelBodyMetrics body
-- | A 'SegSpace' is rendered by handing one @gtid < dim@ document per
-- dimension to 'apply', followed by the physical index wrapped as
-- @(~phys)@.
instance Pretty SegSpace where
  pretty (SegSpace phys dims) =
    apply [pretty gtid <+> "<" <+> pretty dim | (gtid, dim) <- dims]
      <+> parens ("~" <> pretty phys)
-- | A 'SegBinOp' is rendered as its braced neutral elements, its
-- shape, and finally the operator lambda, which is prefixed with
-- @commutative @ when the operator is marked 'Commutative'.
instance PrettyRep rep => Pretty (SegBinOp rep) where
  pretty (SegBinOp comm lam nes shape) =
    neutral
      <> PP.comma
      </> pretty shape
      <> PP.comma
      </> commPrefix comm
      <> pretty lam
    where
      -- The neutral elements, comma-separated inside braces.
      neutral = PP.braces (PP.commasep (map pretty nes))

      -- Only commutative operators get an explicit marker.
      commPrefix Commutative = "commutative "
      commPrefix Noncommutative = mempty
-- | Pretty-printing of the four 'SegOp' forms.  Each one starts with
-- its keyword (@segmap@, @segred@, @segscan@, @seghist@) immediately
-- followed by the level, then the aligned 'SegSpace', then (for all
-- but @segmap@) a parenthesised list of operators, then a colon, the
-- result types and finally the kernel body in a @{ ... }@ block.
instance (PrettyRep rep, PP.Pretty lvl) => PP.Pretty (SegOp lvl rep) where
  pretty segop = case segop of
      SegMap lvl space ts body ->
        "segmap"
          <> pretty lvl
          </> PP.align (pretty space)
          <+> PP.colon
          <+> ppTuple' (map pretty ts)
          <+> bodyBlock body
      SegRed lvl space reds ts body ->
        "segred"
          <> pretty lvl
          </> PP.align (pretty space)
          </> operators (map pretty reds)
          </> PP.colon
          <+> ppTuple' (map pretty ts)
          <+> bodyBlock body
      SegScan lvl space scans ts body ->
        "segscan"
          <> pretty lvl
          </> PP.align (pretty space)
          </> operators (map pretty scans)
          </> PP.colon
          <+> ppTuple' (map pretty ts)
          <+> bodyBlock body
      SegHist lvl space ops ts body ->
        "seghist"
          <> pretty lvl
          </> PP.align (pretty space)
          </> operators (map prettyHistOp ops)
          </> PP.colon
          <+> ppTuple' (map pretty ts)
          <+> bodyBlock body
    where
      -- The kernel body, wrapped in a nested brace block.
      bodyBlock body = PP.nestedBlock "{" "}" (pretty body)

      -- Operator documents separated by a comma and a line break,
      -- all inside one pair of parentheses.
      operators docs =
        PP.parens (mconcat (intersperse (PP.comma <> PP.line) docs))

      -- A histogram operator: shape and race factor on the first
      -- line, then destinations, neutral elements, operator shape and
      -- the operator lambda.
      prettyHistOp (HistOp w rf dests nes shape op) =
        pretty w
          <> PP.comma
          <+> pretty rf
          <> PP.comma
          </> PP.braces (PP.commasep (map pretty dests))
          <> PP.comma
          </> PP.braces (PP.commasep (map pretty nes))
          <> PP.comma
          </> pretty shape
          <> PP.comma
          </> pretty op
-- | Alias information is added by mapping over the 'SegOp' in the
-- 'Identity' monad: lambdas go through 'Alias.analyseLambda', kernel
-- bodies through 'aliasAnalyseKernelBody', and everything else is
-- returned unchanged.
instance CanBeAliased (SegOp lvl) where
  addOpAliases aliases segop = runIdentity (mapSegOpM aliasMapper segop)
    where
      aliasMapper =
        SegOpMapper
          pure
          (\lam -> pure (Alias.analyseLambda aliases lam))
          (\kbody -> pure (aliasAnalyseKernelBody aliases kbody))
          pure
          pure
-- | Convert a kernel body to the 'Wise' representation: the
-- statements are passed through 'informStms' and the body is rebuilt
-- with 'mkWiseKernelBody', keeping the decoration and results.
informKernelBody :: Informing rep => KernelBody rep -> KernelBody (Wise rep)
informKernelBody (KernelBody dec stms res) =
  mkWiseKernelBody dec wise_stms res
  where
    wise_stms = informStms stms
-- | Wisdom is added by mapping over the 'SegOp' in the 'Identity'
-- monad: lambdas go through 'informLambda', kernel bodies through
-- 'informKernelBody', and everything else is returned unchanged.
instance CanBeWise (SegOp lvl) where
  addOpWisdom segop = runIdentity (mapSegOpM wiseMapper segop)
    where
      wiseMapper =
        SegOpMapper
          pure
          (\lam -> pure (informLambda lam))
          (\kbody -> pure (informKernelBody kbody))
          pure
          pure
-- | Symbolic indexing of the results of a 'SegOp'.  Only 'SegMap' is
-- handled; every other form yields 'Nothing'.
--
-- For a 'SegMap', the @k@th kernel result must be a
-- @'Returns' 'ResultMaySimplify'@ of a variable, and at least as many
-- indices as there are segment dimensions must be supplied.  A table
-- is then built that maps each thread index of the space to the
-- corresponding supplied index expression, and the table is extended
-- statement by statement over the kernel body.  The answer is
-- whatever the table holds for the result variable.
instance ASTRep rep => ST.IndexOp (SegOp lvl rep) where
  indexOp vtable k (SegMap _ space _ kbody) is = do
    Returns ResultMaySimplify _ se <- maybeNth k (kernelBodyResult kbody)
    guard (num_gtids <= length is)
    case se of
      Var v -> M.lookup v final_table
      _ -> Nothing
    where
      gtids = map fst (unSegSpace space)
      num_gtids = length gtids

      -- The indices left over once one index has been paired with
      -- each segment dimension.
      excess_is = drop num_gtids is

      -- Each thread index stands for the matching supplied index.
      initial_table =
        M.fromList (zip gtids [ST.Indexed mempty (untyped i) | i <- is])

      final_table =
        foldl' expandIndexedTable initial_table (kernelBodyStms kbody)

      -- Try to record what a single-name statement computes; a
      -- statement that matches neither case leaves the table as is.
      expandIndexedTable table stm =
        case patNames (stmPat stm) of
          [v]
            -- The expression can be turned into a 'PrimExp' over
            -- what is already known.
            | Just (pe, cs) <-
                runWriterT (primExpFromExp (asPrimExp table) (stmExp stm)) ->
                M.insert v (ST.Indexed (stmCerts stm <> cs) pe) table
            -- An index into an array available in the outer symbol
            -- table, whose slice has exactly as many remaining
            -- dimensions as there are excess indices.
            | BasicOp (Index arr slice) <- stmExp stm,
              length (sliceDims slice) == length excess_is,
              arr `ST.available` vtable,
              Just (slice', cs) <- asPrimExpSlice table slice ->
                M.insert
                  v
                  ( ST.IndexedArray
                      (stmCerts stm <> cs)
                      arr
                      (fixSlice (fmap isInt64 slice') excess_is)
                  )
                  table
          _ -> table

      -- Convert every subexpression of a slice, collecting the
      -- certificates picked up along the way.
      asPrimExpSlice table slice =
        runWriterT (traverse (primExpFromSubExpM (asPrimExp table)) slice)

      -- A variable is either already in the table (its certificates
      -- are emitted), or a primitive-typed name known to the outer
      -- symbol table (which becomes a leaf); anything else fails.
      asPrimExp table v
        | Just (ST.Indexed cs e) <- M.lookup v table = tell cs >> pure e
        | Just (Prim pt) <- ST.lookupType v vtable = pure (LeafExp v pt)
        | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing
-- | Every 'SegOp' is reported as not cheap ('cheapOp' is 'False') and
-- as safe ('safeOp' is 'True'), regardless of its contents.
instance
  (ASTRep rep, ASTConstraints lvl) =>
  IsOp (SegOp lvl rep)
  where
  cheapOp = const False
  safeOp = const True
-- | Simplifying a 'SegSpace' simplifies the size of every dimension,
-- in order.  The physical index and the thread-index names are kept
-- as they are.
instance Engine.Simplifiable SegSpace where
  simplify (SegSpace phys dims) = do
    dims' <- mapM simplifyDim dims
    pure (SegSpace phys dims')
    where
      simplifyDim (gtid, dim) = do
        dim' <- Engine.simplify dim
        pure (gtid, dim')
-- | Simplifying a 'KernelResult' simplifies its components from left
-- to right and rebuilds the same constructor.  The 'ResultManifest'
-- of a 'Returns' is carried over untouched.
instance Engine.Simplifiable KernelResult where
  simplify (Returns manifest cs what) = do
    cs' <- Engine.simplify cs
    what' <- Engine.simplify what
    pure (Returns manifest cs' what')
  simplify (WriteReturns cs ws a res) = do
    cs' <- Engine.simplify cs
    ws' <- Engine.simplify ws
    a' <- Engine.simplify a
    res' <- Engine.simplify res
    pure (WriteReturns cs' ws' a' res')
  simplify (TileReturns cs dims what) = do
    cs' <- Engine.simplify cs
    dims' <- Engine.simplify dims
    what' <- Engine.simplify what
    pure (TileReturns cs' dims' what')
  simplify (RegTileReturns cs dims_n_tiles what) = do
    cs' <- Engine.simplify cs
    dims_n_tiles' <- Engine.simplify dims_n_tiles
    what' <- Engine.simplify what
    pure (RegTileReturns cs' dims_n_tiles' what')
-- | Build a 'KernelBody' in the 'Wise' representation.  The body
-- decoration is taken from the 'Body' that 'mkWiseBody' produces for
-- the same statements and the subexpressions of the kernel results;
-- the statements and results themselves are used unchanged.
mkWiseKernelBody ::
  Informing rep =>
  BodyDec rep ->
  Stms (Wise rep) ->
  [KernelResult] ->
  KernelBody (Wise rep)
mkWiseKernelBody dec stms res = KernelBody wise_dec stms res
  where
    -- Only the decoration of the intermediate body is of interest.
    Body wise_dec _ _ =
      mkWiseBody dec stms (subExpsRes (map kernelResultSubExp res))
-- | Build a 'KernelBody' inside a builder monad.  The body decoration
-- is taken from the 'Body' that 'mkBodyM' produces for the same
-- statements and the subexpressions of the kernel results; the
-- statements and results themselves are used unchanged.
mkKernelBodyM ::
  MonadBuilder m =>
  Stms (Rep m) ->
  [KernelResult] ->
  m (KernelBody (Rep m))
mkKernelBodyM stms kres = do
  body <- mkBodyM stms (subExpsRes (map kernelResultSubExp kres))
  case body of
    -- Only the decoration of the intermediate body is of interest.
    Body dec' _ _ -> pure (KernelBody dec' stms kres)
-- | Simplify a kernel body belonging to the given 'SegSpace'.
-- Returns the simplified body together with the statements that were
-- hoisted out of it.
--
-- A statement is blocked from hoisting if any of the following
-- predicates holds: it has a free variable bound by the space, it is
-- an 'Engine.isOp', the parallel hoist blocker from the engine
-- environment applies, or it satisfies 'Engine.isConsumed',
-- 'Engine.isConsuming' or 'Engine.isDeviceMigrated'.
simplifyKernelBody ::
  (Engine.SimplifiableRep rep, BodyDec rep ~ ()) =>
  SegSpace ->
  KernelBody (Wise rep) ->
  Engine.SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
simplifyKernelBody space (KernelBody _ stms res) = do
  par_blocker <-
    Engine.asksEngineEnv (Engine.blockHoistPar . Engine.envHoistBlockers)

  let blocker =
        Engine.hasFree bound_here
          `Engine.orIf` Engine.isOp
          `Engine.orIf` par_blocker
          `Engine.orIf` Engine.isConsumed
          `Engine.orIf` Engine.isConsuming
          `Engine.orIf` Engine.isDeviceMigrated

  (body_res, body_stms, hoisted) <-
    Engine.localVtable consumeWritten $
      Engine.localVtable (<> segSpaceSymbolTable space) $
        Engine.localVtable withMemorySimplification $
          Engine.enterLoop $
            Engine.blockIf blocker stms simplifyResults

  pure (mkWiseKernelBody () body_stms body_res, hoisted)
  where
    -- Names bound by the segment space itself.
    bound_here = namesFromList (M.keys (scopeOfSegSpace space))

    -- Arrays that are the target of a 'WriteReturns' result.
    written = [arr | WriteReturns _ _ arr _ <- res]

    -- Apply 'ST.consume' to the symbol table for every written array.
    consumeWritten vtable = foldl' (flip ST.consume) vtable written

    withMemorySimplification vtable = vtable {ST.simplifyMemory = True}

    -- Applies 'ST.hideCertified' to the names bound by the body's own
    -- statements; used only while the results are simplified.
    hideLocalCerts = ST.hideCertified (namesFromList (M.keys (scopeOf stms)))

    -- Simplify the results and report their free names as usages.
    simplifyResults = do
      res' <- Engine.localVtable hideLocalCerts (mapM Engine.simplify res)
      pure (res', UT.usages (freeIn res'))
-- | Simplify a lambda with 'Engine.simplifyLambda', passing along the
-- given set of names, and wrap the whole action in
-- 'Engine.blockMigrated'.
simplifyLambda ::
  Engine.SimplifiableRep rep =>
  Names ->
  Lambda (Wise rep) ->
  Engine.SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambda bound lam =
  Engine.blockMigrated (Engine.simplifyLambda bound lam)
-- | The symbol table induced by a 'SegSpace': the flat index is
-- entered as an 'Int64' index name, and every @(gtid, dim)@ pair is
-- then inserted as an 'Int64' loop variable with 'ST.insertLoopVar',
-- in order.
segSpaceSymbolTable :: ASTRep rep => SegSpace -> ST.SymbolTable rep
segSpaceSymbolTable (SegSpace flat gtids_and_dims) =
  foldl' insertDim flat_only gtids_and_dims
  where
    -- Starting point: a table that knows only the flat index.
    flat_only = ST.fromScope (M.singleton flat (IndexName Int64))

    insertDim vtable (gtid, dim) = ST.insertLoopVar gtid Int64 dim vtable
-- | Simplify a 'SegBinOp'.  The operator lambda is simplified first,
-- with 'ST.simplifyMemory' switched on in the symbol table and with
-- the given physical index passed to 'simplifyLambda' as the only
-- name in its name set; then the shape and the neutral elements are
-- simplified.  The commutativity is kept.  Returns the new operator
-- and the statements hoisted out of the lambda.
simplifySegBinOp ::
  Engine.SimplifiableRep rep =>
  VName ->
  SegBinOp (Wise rep) ->
  Engine.SimpleM rep (SegBinOp (Wise rep), Stms (Wise rep))
simplifySegBinOp phys_id (SegBinOp comm lam nes shape) = do
  (lam', hoisted) <-
    Engine.localVtable withMemorySimplification simplifiedOperator
  shape' <- Engine.simplify shape
  nes' <- mapM Engine.simplify nes
  pure (SegBinOp comm lam' nes' shape', hoisted)
  where
    withMemorySimplification vtable = vtable {ST.simplifyMemory = True}

    simplifiedOperator = simplifyLambda (oneName phys_id) lam
-- | Simplify a 'SegOp'.
--
-- For every kind of 'SegOp', the level, space and result types are
-- simplified first, then any operators (reduction, scan or histogram
-- operators), and finally the kernel body.  The operators are
-- simplified with the names bound by the 'SegSpace' added to the
-- symbol table.  The result is the rebuilt 'SegOp' paired with the
-- statements hoisted out of the operators (if any) followed by those
-- hoisted out of the kernel body.
simplifySegOp ::
  ( Engine.SimplifiableRep rep,
    BodyDec rep ~ (),
    Engine.Simplifiable lvl
  ) =>
  SegOp lvl (Wise rep) ->
  Engine.SimpleM rep (SegOp lvl (Wise rep), Stms (Wise rep))
simplifySegOp (SegMap lvl space ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  -- Note that the body is simplified with respect to the original
  -- space, not the simplified one.
  (kbody', hoisted) <- simplifyKernelBody space kbody
  pure (SegMap lvl' space' ts' kbody', hoisted)
simplifySegOp (SegRed lvl space reds ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (reds', hoisted_from_reds) <-
    Engine.localVtable (<> space_vtable) . fmap unzip $
      forM reds (simplifySegBinOp (segFlat space))
  (kbody', hoisted_from_body) <- simplifyKernelBody space kbody
  let segop = SegRed lvl' space' reds' ts' kbody'
  pure (segop, mconcat hoisted_from_reds <> hoisted_from_body)
  where
    -- Symbol table containing just the names bound by the space.
    space_vtable = ST.fromScope (scopeOfSegSpace space)
simplifySegOp (SegScan lvl space scans ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (scans', hoisted_from_scans) <-
    Engine.localVtable (<> space_vtable) . fmap unzip $
      forM scans (simplifySegBinOp (segFlat space))
  (kbody', hoisted_from_body) <- simplifyKernelBody space kbody
  let segop = SegScan lvl' space' scans' ts' kbody'
  pure (segop, mconcat hoisted_from_scans <> hoisted_from_body)
  where
    -- Symbol table containing just the names bound by the space.
    space_vtable = ST.fromScope (scopeOfSegSpace space)
simplifySegOp (SegHist lvl space ops ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (ops', hoisted_from_ops) <- unzip <$> mapM simplifyHistOp ops
  (kbody', hoisted_from_body) <- simplifyKernelBody space kbody
  let segop = SegHist lvl' space' ops' ts' kbody'
  pure (segop, mconcat hoisted_from_ops <> hoisted_from_body)
  where
    -- Symbol table containing just the names bound by the space.
    space_vtable = ST.fromScope (scopeOfSegSpace space)

    -- Turn on the 'ST.simplifyMemory' flag of a symbol table.
    withMemorySimplification vtable = vtable {ST.simplifyMemory = True}

    -- Simplify the components of a single histogram operator in
    -- field order; only the lambda can give rise to hoisted
    -- statements.
    simplifyHistOp (HistOp w rf arrs nes dims lam) = do
      w' <- Engine.simplify w
      rf' <- Engine.simplify rf
      arrs' <- Engine.simplify arrs
      nes' <- Engine.simplify nes
      dims' <- Engine.simplify dims
      (lam', hoisted) <-
        Engine.localVtable (<> space_vtable) $
          Engine.localVtable withMemorySimplification $
            simplifyLambda (oneName (segFlat space)) lam
      pure (HistOp w' rf' arrs' nes' dims' lam', hoisted)
-- | Representations whose operations ('Op') can be converted to and
-- from 'SegOp's.  The simplification rules in 'segOpRules' require an
-- instance of this class.
class HasSegOp rep where
-- | The level type used by the 'SegOp's of this representation.
type SegOpLevel rep
-- | Try to view an operation as a 'SegOp'.  Presumably 'Nothing' means
-- that the operation is not a 'SegOp' (this is up to the instance).
asSegOp :: Op rep -> Maybe (SegOp (SegOpLevel rep) rep)
-- | Embed a 'SegOp' as an operation of the representation.
segOp :: SegOp (SegOpLevel rep) rep -> Op rep
-- | The rule book for 'SegOp's: one top-down rule
-- ('segOpRuleTopDown') and one bottom-up rule ('segOpRuleBottomUp').
segOpRules ::
  (HasSegOp rep, BuilderOps rep, Buildable rep, Aliased rep) =>
  RuleBook rep
segOpRules = ruleBook top_down bottom_up
  where
    top_down = [RuleOp segOpRuleTopDown]
    bottom_up = [RuleOp segOpRuleBottomUp]
-- | Top-down rule on operations: if the operation can be viewed as a
-- 'SegOp' (via 'asSegOp'), delegate to 'topDownSegOp'; otherwise the
-- rule does not apply.
segOpRuleTopDown ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  TopDownRuleOp rep
segOpRuleTopDown vtable pat dec op =
  case asSegOp op of
    Just segop -> topDownSegOp vtable pat dec segop
    Nothing -> Skip
-- | Bottom-up rule on operations: if the operation can be viewed as a
-- 'SegOp' (via 'asSegOp'), delegate to 'bottomUpSegOp'; otherwise the
-- rule does not apply.
segOpRuleBottomUp ::
  (HasSegOp rep, BuilderOps rep, Aliased rep) =>
  BottomUpRuleOp rep
segOpRuleBottomUp tables pat dec op =
  case asSegOp op of
    Just segop -> bottomUpSegOp tables pat dec segop
    Nothing -> Skip
-- | Top-down simplification of a 'SegOp'.
--
-- * For a 'SegMap': every result of the form @'Returns' rm cs se@
--   where @cs@ is empty, @rm@ is 'ResultMaySimplify', and @se@ is
--   either a constant or a variable present in the given symbol
--   table, is removed from the 'SegMap' and instead bound to a
--   'Replicate' of @se@ over the dimensions of the space.  If no
--   result is removed, the rule fails with 'cannotSimplify'.
--
-- * For a 'SegRed' with more than one operator: adjacent operators
--   with equal shape are fused into a single operator, provided this
--   fuses at least two of them.
--
-- * Otherwise the rule does not apply.
topDownSegOp ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  ST.SymbolTable rep ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
topDownSegOp vtable (Pat kpes) dec (SegMap lvl space ts (KernelBody _ kstms kres)) =
  Simplify $ do
    -- 'keepResult' binds the results it rejects as a side effect.
    kept <- filterM keepResult $ zip3 ts kpes kres
    let (ts', kpes', kres') = unzip3 kept

    -- Nothing was removed, so there is nothing to do.
    when (kres == kres') cannotSimplify

    kbody <- mkKernelBodyM kstms kres'
    addStm . Let (Pat kpes') dec . Op . segOp $
      SegMap lvl space ts' kbody
  where
    -- A result is considered invariant if it is a constant or a
    -- variable that can be found in the symbol table.
    isInvariant se =
      case se of
        Constant {} -> True
        Var v -> isJust (ST.lookup v vtable)

    -- Returns False (after binding the pattern element to a
    -- replicate) for results that can be removed, and True for
    -- those that must stay in the SegMap.
    keepResult (_, pe, Returns rm cs se)
      | cs == mempty,
        rm == ResultMaySimplify,
        isInvariant se = do
          letBindNames [patElemName pe] . BasicOp $
            Replicate (Shape (segSpaceDims space)) se
          pure False
    keepResult _ =
      pure True
topDownSegOp _ (Pat pes) _ (SegRed lvl space ops ts kbody)
  | length ops > 1,
    any ((> 1) . length) groups = Simplify $ do
      let (ops', per_op) = unzip (mapMaybe combineOps groups)
          (red_pes', red_ts', red_res') = unzip3 (concat per_op)
          kbody' = kbody {kernelBodyResult = red_res' ++ map_res}
      letBind (Pat (red_pes' ++ map_pes)) . Op . segOp $
        SegRed lvl space ops' (red_ts' ++ map_ts) kbody'
  where
    -- The first results belong to the reduction operators; the
    -- remainder are handled separately and left untouched.
    num_red_res = segBinOpResults ops
    (red_pes, map_pes) = splitAt num_red_res pes
    (red_ts, map_ts) = splitAt num_red_res ts
    (red_res, map_res) = splitAt num_red_res (kernelBodyResult kbody)

    -- Each operator paired with the pattern elements, types and
    -- results that belong to it.
    ops_with_results =
      zip ops $
        chunks (map (length . segBinOpNeutral) ops) $
          zip3 red_pes red_ts red_res

    -- Adjacent operators of the same shape end up in the same group.
    groups = groupBy sameShape ops_with_results

    sameShape (x, _) (y, _) = segBinOpShape x == segBinOpShape y

    -- Fuse all the operators of a group, left to right.
    combineOps group =
      case group of
        [] -> Nothing
        first_op : more_ops -> Just (foldl' combine first_op more_ops)

    -- Fuse two operators into one that computes both.  The
    -- parameters are ordered as the "x" parameters of both operators
    -- followed by the "y" parameters of both; the shape is taken
    -- from the first operator.
    combine (SegBinOp comm1 lam1 nes1 shape1, aux1) (SegBinOp comm2 lam2 nes2 _, aux2) =
      let (xparams1, yparams1) = splitAt (length nes1) (lambdaParams lam1)
          (xparams2, yparams2) = splitAt (length nes2) (lambdaParams lam2)
          body1 = lambdaBody lam1
          body2 = lambdaBody lam2
          fused_lam =
            Lambda
              { lambdaParams =
                  concat [xparams1, xparams2, yparams1, yparams2],
                lambdaReturnType =
                  lambdaReturnType lam1 ++ lambdaReturnType lam2,
                lambdaBody =
                  mkBody
                    (bodyStms body1 <> bodyStms body2)
                    (bodyResult body1 <> bodyResult body2)
              }
       in ( SegBinOp (comm1 <> comm2) fused_lam (nes1 ++ nes2) shape1,
            aux1 ++ aux2
          )
topDownSegOp _ _ _ _ = Skip
-- | Take a 'SegOp' apart.  The result contains, in order: the result
-- types, the kernel body, a count associated with the operators (zero
-- for 'SegMap', 'segBinOpResults' of the operators for 'SegScan' and
-- 'SegRed', and the total number of 'histDest' arrays for 'SegHist'),
-- and a function for rebuilding the same kind of 'SegOp' (with the
-- same level, space and operators) from new result types and a new
-- kernel body.
segOpGuts ::
  SegOp (SegOpLevel rep) rep ->
  ( [Type],
    KernelBody rep,
    Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
  )
segOpGuts segop =
  case segop of
    SegMap lvl space kts body ->
      (kts, body, 0, SegMap lvl space)
    SegScan lvl space ops kts body ->
      (kts, body, segBinOpResults ops, SegScan lvl space ops)
    SegRed lvl space ops kts body ->
      (kts, body, segBinOpResults ops, SegRed lvl space ops)
    SegHist lvl space ops kts body ->
      (kts, body, sum (map (length . histDest) ops), SegHist lvl space ops)
bottomUpSegOp ::
(Aliased rep, HasSegOp rep, BuilderOps rep) =>
(ST.SymbolTable rep, UT.UsageTable) ->
Pat (LetDec rep) ->
StmAux (ExpDec rep) ->
SegOp (SegOpLevel rep) rep ->
Rule rep
bottomUpSegOp :: forall rep.
(Aliased rep, HasSegOp rep, BuilderOps rep) =>
(SymbolTable rep, UsageTable)
-> Pat (LetDec rep)
-> StmAux (ExpDec rep)
-> SegOp (SegOpLevel rep) rep
-> Rule rep
bottomUpSegOp (SymbolTable rep
vtable, UsageTable
used) (Pat [PatElem (LetDec rep)]
kpes) StmAux (ExpDec rep)
dec SegOp (SegOpLevel rep) rep
segop = forall rep. RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$ do
([PatElem (LetDec rep)]
kpes', [Type]
kts', [KernelResult]
kres', Stms rep
kstms') <-
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> Stm rep
-> RuleM
rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
distribute ([PatElem (LetDec rep)]
kpes, [Type]
kts, [KernelResult]
kres, forall a. Monoid a => a
mempty) Stms rep
kstms
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when
([PatElem (LetDec rep)]
kpes' forall a. Eq a => a -> a -> Bool
== [PatElem (LetDec rep)]
kpes)
forall rep a. RuleM rep a
cannotSimplify
KernelBody rep
kbody' <-
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) forall a b. (a -> b) -> a -> b
$
forall (m :: * -> *).
MonadBuilder m =>
Stms (Rep m) -> [KernelResult] -> m (KernelBody (Rep m))
mkKernelBodyM Stms rep
kstms' [KernelResult]
kres'
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm forall a b. (a -> b) -> a -> b
$ forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let (forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
kpes') StmAux (ExpDec rep)
dec forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. HasSegOp rep => SegOp (SegOpLevel rep) rep -> Op rep
segOp forall a b. (a -> b) -> a -> b
$ [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
mk_segop [Type]
kts' KernelBody rep
kbody'
where
([Type]
kts, kbody :: KernelBody rep
kbody@(KernelBody BodyDec rep
_ Stms rep
kstms [KernelResult]
kres), Int
num_nonmap_results, [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
mk_segop) =
forall rep.
SegOp (SegOpLevel rep) rep
-> ([Type], KernelBody rep, Int,
[Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep)
segOpGuts SegOp (SegOpLevel rep) rep
segop
free_in_kstms :: Names
free_in_kstms = forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap forall a. FreeIn a => a -> Names
freeIn Stms rep
kstms
consumed_in_segop :: Names
consumed_in_segop = forall rep. Aliased rep => KernelBody rep -> Names
consumedInKernelBody KernelBody rep
kbody
space :: SegSpace
space = forall lvl rep. SegOp lvl rep -> SegSpace
segSpace SegOp (SegOpLevel rep) rep
segop
sliceWithGtidsFixed :: Stm rep -> Maybe (Slice SubExp, VName)
sliceWithGtidsFixed Stm rep
stm
| Let Pat (LetDec rep)
_ StmAux (ExpDec rep)
_ (BasicOp (Index VName
arr Slice SubExp
slice)) <- Stm rep
stm,
[DimIndex SubExp]
space_slice <- forall a b. (a -> b) -> [a] -> [b]
map (forall d. d -> DimIndex d
DimFix forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. VName -> SubExp
Var forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall a b. (a, b) -> a
fst) forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space,
[DimIndex SubExp]
space_slice forall a. Eq a => [a] -> [a] -> Bool
`isPrefixOf` forall d. Slice d -> [DimIndex d]
unSlice Slice SubExp
slice,
Slice SubExp
remaining_slice <- forall d. [DimIndex d] -> Slice d
Slice forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
drop (forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex SubExp]
space_slice) (forall d. Slice d -> [DimIndex d]
unSlice Slice SubExp
slice),
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall a. Maybe a -> Bool
isJust forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall a b c. (a -> b -> c) -> b -> a -> c
flip forall rep. VName -> SymbolTable rep -> Maybe (Entry rep)
ST.lookup SymbolTable rep
vtable) forall a b. (a -> b) -> a -> b
$
Names -> [VName]
namesToList forall a b. (a -> b) -> a -> b
$
forall a. FreeIn a => a -> Names
freeIn VName
arr forall a. Semigroup a => a -> a -> a
<> forall a. FreeIn a => a -> Names
freeIn Slice SubExp
remaining_slice =
forall a. a -> Maybe a
Just (Slice SubExp
remaining_slice, VName
arr)
| Bool
otherwise =
forall a. Maybe a
Nothing
-- Fold step that tries to replace one statement of the kernel body by an
-- 'Index' (optionally followed by a 'Copy') emitted through the builder.
--
-- Inferred type (kept as a comment rather than a local signature, because
-- @rep@ would have to be bound by the enclosing definition's signature):
--
-- > distribute ::
-- >   ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep) ->
-- >   Stm rep ->
-- >   RuleM rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
--
-- The accumulator holds the pattern elements, types and kernel results that
-- are still attached to the SegOp, plus the body statements kept so far.
-- @space@, @used@, @consumed_in_segop@ and @free_in_kstms@ are taken from
-- the enclosing scope.
distribute (kpes', kts', kres', kstms') stm
  | Let (Pat [pe]) _ _ <- stm,
    Just (Slice remaining_slice, arr) <- sliceWithGtidsFixed stm,
    Just (kpe, kpes'', kts'', kres'') <- isResult kpes' kts' kres' pe = do
      let -- One full-range slice (offset 0, stride 1) per dimension of the
          -- segment space.
          outer_slice =
            map
              ( \d ->
                  DimSlice
                    (constant (0 :: Int64))
                    d
                    (constant (1 :: Int64))
              )
              $ segSpaceDims space
          -- Bind the name of the given pattern element to @arr@ indexed by
          -- the outer slices followed by the remaining inner slice.
          index kpe' =
            letBindNames [patElemName kpe'] . BasicOp . Index arr $
              Slice $
                outer_slice <> remaining_slice
      -- If the result is consumed according to @used@, or @arr@ is among
      -- the names in @consumed_in_segop@, index into a fresh "_precopy"
      -- name first and bind the real result to a copy of it.
      if patElemName kpe `UT.isConsumed` used
        || arr `nameIn` consumed_in_segop
        then do
          precopy <- newVName $ baseString (patElemName kpe) <> "_precopy"
          index kpe {patElemName = precopy}
          letBindNames [patElemName kpe] $ BasicOp $ Copy precopy
        else index kpe
      pure
        ( kpes'',
          kts'',
          kres'',
          -- The statement is retained only when its bound name is in
          -- @free_in_kstms@.
          if patElemName pe `nameIn` free_in_kstms
            then kstms' <> oneStm stm
            else kstms'
        )
-- Any other statement is kept unchanged and the accumulator is otherwise
-- left alone.
distribute (kpes', kts', kres', kstms') stm =
  pure (kpes', kts', kres', kstms' <> oneStm stm)
-- Determine whether the pattern element @pe@ is returned by exactly one of
-- the given kernel results, and if so split that result off.
--
-- Inferred type (kept as a comment rather than a local signature, because
-- @rep@ would have to be bound by the enclosing definition's signature):
--
-- > isResult ::
-- >   [PatElem (LetDec rep)] ->
-- >   [Type] ->
-- >   [KernelResult] ->
-- >   PatElem (LetDec rep) ->
-- >   Maybe (PatElem (LetDec rep), [PatElem (LetDec rep)], [Type], [KernelResult])
--
-- Succeeds only when a single triple matches and the matching pattern
-- element sits at an index of at least @num_nonmap_results@ in @kpes@ (both
-- names come from the enclosing scope).  On success, returns the matching
-- pattern element together with the remaining pattern elements, types and
-- kernel results.
isResult kpes' kts' kres' pe =
  case partition matches $ zip3 kpes' kts' kres' of
    ([(kpe, _, _)], kpes_and_kres)
      | Just i <- elemIndex kpe kpes,
        i >= num_nonmap_results,
        (kpes'', kts'', kres'') <- unzip3 kpes_and_kres ->
          Just (kpe, kpes'', kts'', kres'')
    _ -> Nothing
  where
    -- A triple matches when its kernel result is a 'Returns' of exactly the
    -- variable bound by @pe@.
    matches (_, _, Returns _ _ (Var v)) = v == patElemName pe
    matches _ = False
-- | Adjust a list of 'ExpReturns' according to the results of a kernel
-- body, pairing each kernel result with the return at the same position.
--
-- For a 'WriteReturns' result the supplied return is discarded and replaced
-- by @varReturns@ of the array named in the result; every other kind of
-- result keeps the supplied return unchanged.
kernelBodyReturns ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  KernelBody somerep ->
  [ExpReturns] ->
  m [ExpReturns]
kernelBodyReturns = zipWithM correct . kernelBodyResult
  where
    correct (WriteReturns _ _ arr _) _ = varReturns arr
    correct _ ret = pure ret
-- | Compute the 'ExpReturns' of a 'SegOp'.
--
-- For 'SegMap', 'SegRed' and 'SegScan' the returns are derived from
-- @opType@ via @extReturns@ and then adjusted by 'kernelBodyReturns' using
-- the operation's kernel body.  For 'SegHist' the kernel body is ignored and
-- the result is @varReturns@ of every destination array ('histDest') of
-- every histogram operation, in order.
segOpReturns ::
  (Mem rep inner, Monad m, HasScope rep m) =>
  SegOp lvl somerep ->
  m [ExpReturns]
segOpReturns k@(SegMap _ _ _ kbody) =
  kernelBodyReturns kbody . extReturns =<< opType k
segOpReturns k@(SegRed _ _ _ _ kbody) =
  kernelBodyReturns kbody . extReturns =<< opType k
segOpReturns k@(SegScan _ _ _ _ kbody) =
  kernelBodyReturns kbody . extReturns =<< opType k
segOpReturns (SegHist _ _ ops _ _) =
  concat <$> mapM (mapM varReturns . histDest) ops