{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.IR.SegOp
( SegOp (..),
SegVirt (..),
SegSeqDims (..),
segLevel,
segBody,
segSpace,
typeCheckSegOp,
SegSpace (..),
scopeOfSegSpace,
segSpaceDims,
HistOp (..),
histType,
splitHistResults,
SegBinOp (..),
segBinOpResults,
segBinOpChunks,
KernelBody (..),
aliasAnalyseKernelBody,
consumedInKernelBody,
ResultManifest (..),
KernelResult (..),
kernelResultCerts,
kernelResultSubExp,
SegOpMapper (..),
identitySegOpMapper,
mapSegOpM,
traverseSegOpStms,
simplifySegOp,
HasSegOp (..),
segOpRules,
segOpReturns,
)
where
import Control.Category
import Control.Monad.Identity hiding (mapM_)
import Control.Monad.Reader hiding (mapM_)
import Control.Monad.State.Strict
import Control.Monad.Writer hiding (mapM_)
import Data.Bifunctor (first)
import Data.Bitraversable
import Data.Foldable (traverse_)
import Data.List
( elemIndex,
foldl',
groupBy,
intersperse,
isPrefixOf,
partition,
)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR
import Futhark.IR.Aliases
( Aliases,
removeLambdaAliases,
removeStmAliases,
)
import Futhark.IR.Mem
import Futhark.IR.Prop.Aliases
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty
( Doc,
Pretty,
apply,
hsep,
parens,
ppTuple',
pretty,
(<+>),
(</>),
)
import Futhark.Util.Pretty qualified as PP
import Prelude hiding (id, (.))
-- | A histogram operator, as consumed by 'histType' and 'splitHistResults'.
data HistOp rep = HistOp
  { -- | Shape of the histogram destination.  'splitHistResults' expects
    -- one index result per dimension of this shape.
    histShape :: Shape,
    -- | NOTE(review): presumably an estimate of write contention used to
    -- pick an implementation strategy; its use is not visible here.
    histRaceFactor :: SubExp,
    -- | Destination arrays.  'splitHistResults' expects one value result
    -- per destination.
    histDest :: [VName],
    -- | Neutral elements (presumably for 'histOp' — confirm).
    histNeutral :: [SubExp],
    -- | Extra shape appended after 'histShape' when computing 'histType'.
    histOpShape :: Shape,
    -- | The combining operator.  Its return types, wrapped in
    -- @histShape <> histOpShape@, form 'histType'.
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)
-- | The types of the histograms produced by a 'HistOp'.
--
-- Each return type of 'histOp' is turned into an array whose shape is
-- 'histShape' followed by 'histOpShape'.
histType :: HistOp rep -> [Type]
histType op = map (`arrayOfShape` full_shape) (lambdaReturnType (histOp op))
  where
    full_shape = histShape op <> histOpShape op
-- | Split a flat list of results into, per 'HistOp', its index results
-- and its value results.
--
-- The input is expected to hold all index results first (one per
-- dimension of each operator's 'histShape', in operator order), followed
-- by all value results (one per 'histDest' entry, in operator order).
splitHistResults :: [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ops res =
  zip (chunks ranks idxs) (chunks value_counts vals)
  where
    ranks = map (shapeRank . histShape) ops
    value_counts = map (length . histDest) ops
    (idxs, vals) = splitAt (sum ranks) res
-- | A binary operator together with its commutativity, neutral elements
-- and an operator shape.
data SegBinOp rep = SegBinOp
  { segBinOpComm :: Commutativity,
    segBinOpLambda :: Lambda rep,
    -- | Neutral elements.  'segBinOpResults' and 'segBinOpChunks' count
    -- one result per element of this list.
    segBinOpNeutral :: [SubExp],
    -- | NOTE(review): presumably extra dimensions over which the operator
    -- is vectorised, analogous to 'histOpShape' — confirm.
    segBinOpShape :: Shape
  }
  deriving (Eq, Ord, Show)
-- | Total number of results produced by a list of 'SegBinOp's.
-- Each operator contributes one result per neutral element.
segBinOpResults :: [SegBinOp rep] -> Int
segBinOpResults ops = sum [length (segBinOpNeutral op) | op <- ops]
-- | Split a list into consecutive chunks, one per 'SegBinOp'.  Each chunk
-- is as long as that operator's list of neutral elements.
segBinOpChunks :: [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks ops = chunks [length (segBinOpNeutral op) | op <- ops]
-- | A body of statements whose results are 'KernelResult's rather than
-- plain subexpressions.
data KernelBody rep = KernelBody
  { kernelBodyDec :: BodyDec rep,
    kernelBodyStms :: Stms rep,
    kernelBodyResult :: [KernelResult]
  }

-- Standalone deriving with an explicit 'RepTypes' context.
-- NOTE(review): presumably needed because fields such as 'BodyDec' are
-- type-family applications, which ordinary deriving cannot infer a
-- context for.
deriving instance RepTypes rep => Eq (KernelBody rep)

deriving instance RepTypes rep => Ord (KernelBody rep)

deriving instance RepTypes rep => Show (KernelBody rep)
-- | A tag carried by a 'Returns' kernel result.
--
-- Constructor order matters: the derived 'Ord' instance orders the
-- constructors as declared.
data ResultManifest
  = -- | NOTE(review): by its name, marks a result the simplifier should
    -- leave alone — confirm.
    ResultNoSimplify
  | -- | NOTE(review): by its name, marks a result the simplifier may
    -- rewrite — confirm.
    ResultMaySimplify
  | -- | NOTE(review): by its name, marks a result private to the
    -- producing thread — confirm.
    ResultPrivate
  deriving (Eq, Show, Ord)
-- | A single result of a 'KernelBody'.  'kernelResultSubExp' gives the
-- value each constructor denotes, and 'kernelResultCerts' its
-- certificates.
--
-- Constructor order matters: the derived 'Ord' instance orders the
-- constructors as declared.
data KernelResult
  = -- | A plain value, tagged with a 'ResultManifest' and certificates.
    Returns ResultManifest Certs SubExp
  | -- | Writes into an existing array, which is consumed.  The type
    -- checker requires every dimension of the 'Shape' to be @i64@, every
    -- slice index to be @i64@, each written value to have the kernel
    -- return type @t@, and the array to have type @t@ extended by the
    -- 'Shape'.
    WriteReturns
      Certs
      Shape
      VName
      [(Slice SubExp, SubExp)]
  | -- | A tiled result.  Each pair holds two @i64@ sizes (checked as a
    -- dimension and a tile size).  The array's shape must consist of the
    -- second components.
    TileReturns
      Certs
      [(SubExp, SubExp)]
      VName
  | -- | A register-tiled result.  Each triple holds three @i64@ sizes
    -- (checked as dimension, block tile and register tile).  The array's
    -- shape must be all block tiles followed by all register tiles.
    RegTileReturns
      Certs
      [ ( SubExp,
          SubExp,
          SubExp
        )
      ]
      VName
  deriving (Eq, Show, Ord)
-- | The certificates attached to a 'KernelResult'.  Every constructor
-- carries them as its 'Certs' field.
kernelResultCerts :: KernelResult -> Certs
kernelResultCerts kres = case kres of
  Returns _ cs _ -> cs
  WriteReturns cs _ _ _ -> cs
  TileReturns cs _ _ -> cs
  RegTileReturns cs _ _ -> cs
-- | The subexpression a 'KernelResult' denotes.
--
-- For 'Returns' this is the returned value.  For every other constructor
-- it is the array variable the result refers to.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp kres = case kres of
  Returns _ _ se -> se
  WriteReturns _ _ arr _ -> Var arr
  TileReturns _ _ v -> Var v
  RegTileReturns _ _ v -> Var v
-- | Every component of a 'KernelResult' contributes its free variables,
-- including the target array names.  The 'ResultManifest' of 'Returns'
-- contributes nothing.
instance FreeIn KernelResult where
  freeIn' (Returns _ cs what) =
    freeIn' cs <> freeIn' what
  freeIn' (WriteReturns cs rws arr res) =
    freeIn' cs <> freeIn' rws <> freeIn' arr <> freeIn' res
  freeIn' (TileReturns cs dims v) =
    freeIn' cs <> freeIn' dims <> freeIn' v
  freeIn' (RegTileReturns cs dims_n_tiles v) =
    freeIn' cs <> freeIn' dims_n_tiles <> freeIn' v
-- | The free variables of a 'KernelBody' are those of its decoration,
-- statements and results, minus every name bound by its statements.
instance ASTRep rep => FreeIn (KernelBody rep) where
  freeIn' (KernelBody dec stms res) =
    fvBind bound_in_stms (freeIn' dec <> freeIn' stms <> freeIn' res)
    where
      bound_in_stms = foldMap boundByStm stms
-- | Apply a name substitution to the decoration, the statements and the
-- results of a 'KernelBody'.
instance ASTRep rep => Substitute (KernelBody rep) where
  substituteNames subst (KernelBody dec stms res) =
    KernelBody
      (substituteNames subst dec)
      (substituteNames subst stms)
      (substituteNames subst res)
-- | Apply a name substitution to every component of a 'KernelResult',
-- including the target array names.  The 'ResultManifest' of 'Returns'
-- is kept unchanged.
instance Substitute KernelResult where
  substituteNames subst (Returns manifest cs se) =
    Returns
      manifest
      (substituteNames subst cs)
      (substituteNames subst se)
  substituteNames subst (WriteReturns cs rws arr res) =
    WriteReturns
      (substituteNames subst cs)
      (substituteNames subst rws)
      (substituteNames subst arr)
      (substituteNames subst res)
  substituteNames subst (TileReturns cs dims v) =
    TileReturns
      (substituteNames subst cs)
      (substituteNames subst dims)
      (substituteNames subst v)
  substituteNames subst (RegTileReturns cs dims_n_tiles v) =
    RegTileReturns
      (substituteNames subst cs)
      (substituteNames subst dims_n_tiles)
      (substituteNames subst v)
-- | Rename a 'KernelBody'.
--
-- The decoration is renamed first and outside the statement scope.  The
-- results are renamed inside the continuation of 'renamingStms', so they
-- see the renamed statement bindings.
instance ASTRep rep => Rename (KernelBody rep) where
  rename (KernelBody dec stms res) = do
    dec' <- rename dec
    renamingStms stms $ \stms' -> do
      res' <- rename res
      pure $ KernelBody dec' stms' res'
-- | Renaming a 'KernelResult' is delegated to 'substituteRename'.
-- NOTE(review): presumably this applies the current renaming as a
-- substitution, since a result introduces no binders of its own.
instance Rename KernelResult where
  rename = substituteRename
-- | Run alias analysis on a 'KernelBody'.
--
-- The statements are analysed by wrapping them in an ordinary 'Body' with
-- an empty result.  The 'KernelResult's are carried over unchanged.
aliasAnalyseKernelBody ::
  ( ASTRep rep,
    CanBeAliased (Op rep)
  ) =>
  AliasTable ->
  KernelBody rep ->
  KernelBody (Aliases rep)
aliasAnalyseKernelBody aliases (KernelBody dec stms res) =
  KernelBody dec' stms' res
  where
    Body dec' stms' _ = Alias.analyseBody aliases (Body dec stms [])
-- | Strip alias information from a 'KernelBody'.
--
-- The aliasing component of the body decoration is dropped, and
-- 'removeStmAliases' is applied to every statement.  The results are
-- unchanged.
removeKernelBodyAliases ::
  CanBeAliased (Op rep) =>
  KernelBody (Aliases rep) ->
  KernelBody rep
removeKernelBodyAliases (KernelBody (_, dec) stms res) =
  KernelBody dec (removeStmAliases <$> stms) res
-- | Strip simplifier wisdom from a 'KernelBody'.
--
-- The decoration and statements are wrapped in an ordinary 'Body' with an
-- empty result and passed through 'removeBodyWisdom'.  The results are
-- unchanged.
removeKernelBodyWisdom ::
  CanBeWise (Op rep) =>
  KernelBody (Wise rep) ->
  KernelBody rep
removeKernelBodyWisdom (KernelBody dec stms res) =
  KernelBody dec' stms' res
  where
    Body dec' stms' _ = removeBodyWisdom (Body dec stms [])
-- | The names consumed by a 'KernelBody'.
--
-- This is whatever its statements consume, plus the target array of every
-- 'WriteReturns' result.  Only the array name itself is added here, not
-- its aliases.
consumedInKernelBody ::
  Aliased rep =>
  KernelBody rep ->
  Names
consumedInKernelBody (KernelBody dec stms res) =
  consumedInBody (Body dec stms []) <> foldMap consumedByReturn res
  where
    consumedByReturn :: KernelResult -> Names
    consumedByReturn (WriteReturns _ _ a _) = oneName a
    consumedByReturn _ = mempty
-- | Type-check a 'KernelBody' against the expected return types @ts@.
--
-- The body decoration is checked first.  Next, the arrays targeted by
-- 'WriteReturns' results (and their aliases) are consumed.  This happens
-- before the statements are checked.  NOTE(review): presumably this is so
-- that statements using such arrays are reported as using consumed
-- values — confirm.
--
-- Inside the scope of the statements, the number of results must match
-- @ts@, and each result is checked against its type.
checkKernelBody ::
  TC.Checkable rep =>
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkKernelBody ts (KernelBody (_, dec) stms kres) = do
  TC.checkBodyDec dec
  mapM_ consumeKernelResult kres
  TC.checkStms stms $ do
    unless (length ts == length kres) $
      TC.bad . TC.TypeError $
        "Kernel return type is "
          <> prettyTuple ts
          <> ", but body returns "
          <> prettyText (length kres)
          <> " values."
    zipWithM_ checkKernelResult kres ts
  where
    -- Only 'WriteReturns' consumes anything: its target array and aliases.
    consumeKernelResult (WriteReturns _ _ arr _) =
      TC.consume =<< TC.lookupAliases arr
    consumeKernelResult _ =
      pure ()

    checkKernelResult (Returns _ cs what) t = do
      TC.checkCerts cs
      TC.require [t] what
    checkKernelResult (WriteReturns cs shape arr res) t = do
      TC.checkCerts cs
      mapM_ (TC.require [Prim int64]) $ shapeDims shape
      arr_t <- lookupType arr
      forM_ res $ \(slice, e) -> do
        traverse_ (TC.require [Prim int64]) slice
        TC.require [t] e
        -- NOTE(review): apart from the message, this check does not depend
        -- on the loop variables, and it is skipped when @res@ is empty.
        unless (arr_t == t `arrayOfShape` shape) $
          TC.bad $
            TC.TypeError $
              "WriteReturns returning "
                <> prettyText e
                <> " of type "
                <> prettyText t
                <> ", shape="
                <> prettyText shape
                <> ", but destination array has type "
                <> prettyText arr_t
    checkKernelResult (TileReturns cs dims v) t = do
      TC.checkCerts cs
      forM_ dims $ \(dim, tile) -> do
        TC.require [Prim int64] dim
        TC.require [Prim int64] tile
      vt <- lookupType v
      -- The tile array's shape is given by the second components.
      unless (vt == t `arrayOfShape` Shape (map snd dims)) $
        TC.bad $
          TC.TypeError $
            "Invalid type for TileReturns " <> prettyText v
    checkKernelResult (RegTileReturns cs dims_n_tiles arr) t = do
      TC.checkCerts cs
      mapM_ (TC.require [Prim int64]) dims
      mapM_ (TC.require [Prim int64]) blk_tiles
      mapM_ (TC.require [Prim int64]) reg_tiles
      arr_t <- lookupType arr
      -- The message names RegTileReturns; previously it wrongly said
      -- TileReturns.
      unless (arr_t == expected) $
        TC.bad . TC.TypeError $
          "Invalid type for RegTileReturns. Expected:\n "
            <> prettyText expected
            <> ",\ngot:\n "
            <> prettyText arr_t
      where
        (dims, blk_tiles, reg_tiles) = unzip3 dims_n_tiles
        -- All block tiles followed by all register tiles.
        expected = t `arrayOfShape` Shape (blk_tiles <> reg_tiles)
-- | Accumulate the metrics of every statement in a kernel body.  The
-- kernel results themselves are not visited.
kernelBodyMetrics :: (OpMetrics (Op rep)) => KernelBody rep -> MetricsM ()
kernelBodyMetrics = mapM_ stmMetrics . kernelBodyStms
-- | Render the stacked statements of the body, followed by
-- @return {r1, r2, ...}@ listing the kernel results.  The body
-- decoration is not printed.
instance (PrettyRep rep) => Pretty (KernelBody rep) where
  pretty (KernelBody _ stms res) =
    PP.stack (map pretty (stmsToList stms))
      </> "return"
      <+> PP.braces (PP.commasep $ map pretty res)
-- | The certificates rendered as a list of documents, ready to be
-- prepended to a result's rendering.  Empty certificates produce no
-- document at all, so nothing is printed for them.
certAnnots :: Certs -> [Doc ann]
certAnnots cs
  | cs == mempty = []
  | otherwise = [pretty cs]
-- | Render a kernel result, preceded by its certificates (if any).
instance Pretty KernelResult where
  -- All three 'ResultManifest' variants share one layout and differ
  -- only in the keyword.
  pretty (Returns manifest cs what) =
    hsep $ certAnnots cs <> [keyword <+> pretty what]
    where
      keyword = case manifest of
        ResultNoSimplify -> "returns (manifest)"
        ResultPrivate -> "returns (private)"
        ResultMaySimplify -> "returns"
  pretty (WriteReturns cs shape arr res) =
    hsep $
      certAnnots cs
        <> [ pretty arr
               <+> PP.colon
               <+> pretty shape
               </> "with"
               <+> PP.apply (map ppRes res)
           ]
    where
      -- Each write is printed as @slice = value@.
      ppRes (slice, e) = pretty slice <+> "=" <+> pretty e
  pretty (TileReturns cs dims v) =
    hsep $ certAnnots cs <> ["tile" <> apply (map onDim dims) <+> pretty v]
    where
      -- Each dimension is printed as @dim / tile@.
      onDim (dim, tile) = pretty dim <+> "/" <+> pretty tile
  pretty (RegTileReturns cs dims_n_tiles v) =
    hsep $
      certAnnots cs
        <> ["blkreg_tile" <> apply (map onDim dims_n_tiles) <+> pretty v]
    where
      -- Each dimension is printed as @dim / (blk_tile * reg_tile)@.
      onDim (dim, blk_tile, reg_tile) =
        pretty dim <+> "/" <+> parens (pretty blk_tile <+> "*" <+> pretty reg_tile)
-- | A list of dimension indices.  Its only use visible in this part of
-- the module is as the payload of 'SegNoVirtFull'.
--
-- NOTE(review): presumably the indices refer to dimensions of the
-- operation's 'SegSpace' — TODO confirm against the code generators.
newtype SegSeqDims = SegSeqDims {segSeqDims :: [Int]}
  deriving (Eq, Ord, Show)
-- | One of three modes for a segmented operation.
--
-- NOTE(review): the constructor names suggest a choice of
-- virtualisation strategy, but their semantics are not established in
-- this part of the module — confirm against the code generators.
data SegVirt
  = SegVirt
  | SegNoVirt
  | -- | Carries a 'SegSeqDims' list of dimension indices.
    SegNoVirtFull SegSeqDims
  deriving (Eq, Ord, Show)
-- | The index space of a segmented operation.
data SegSpace = SegSpace
  { -- | An extra index variable.  'scopeOfSegSpace' binds it, together
    -- with the per-dimension index variables, as an 'Int64' index.
    segFlat :: VName,
    -- | For each dimension, its index variable and its size.  The sizes
    -- must be @i64@ (see 'checkSegSpace').
    unSegSpace :: [(VName, SubExp)]
  }
  deriving (Eq, Ord, Show)
-- | The size of every dimension of the space, in the order they are
-- listed in 'unSegSpace'.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims = map snd . unSegSpace
-- | A scope binding the space's 'segFlat' variable and every
-- per-dimension index variable as an 'Int64' index name.
scopeOfSegSpace :: SegSpace -> Scope rep
scopeOfSegSpace (SegSpace phys space) =
  M.fromList [(v, IndexName Int64) | v <- phys : map fst space]
-- | Check that the size of every dimension of the space is an @i64@.
checkSegSpace :: (TC.Checkable rep) => SegSpace -> TC.TypeM rep ()
checkSegSpace (SegSpace _ dims) =
  mapM_ (TC.require [Prim int64] . snd) dims
-- | A segmented operation over a 'SegSpace'.  Every constructor carries a
-- level, the space, the types of the kernel body's results, and the
-- 'KernelBody' itself.  'SegRed' and 'SegScan' additionally carry their
-- binary operators, and 'SegHist' its histogram operators.
data SegOp lvl rep
  = SegMap lvl SegSpace [Type] (KernelBody rep)
  | -- | 'segOpType' takes 'init' of the space's dimensions for this
    -- constructor, so the space is assumed to have at least one
    -- dimension.
    SegRed lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | SegScan lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | SegHist lvl SegSpace [HistOp rep] [Type] (KernelBody rep)
  deriving (Eq, Ord, Show)
-- | The level at which a segmented operation executes.
segLevel :: SegOp lvl rep -> lvl
segLevel segop = case segop of
  SegMap lvl _ _ _ -> lvl
  SegRed lvl _ _ _ _ -> lvl
  SegScan lvl _ _ _ _ -> lvl
  SegHist lvl _ _ _ _ -> lvl
-- | The index space of a segmented operation.
segSpace :: SegOp lvl rep -> SegSpace
segSpace segop = case segop of
  SegMap _ space _ _ -> space
  SegRed _ space _ _ _ -> space
  SegScan _ space _ _ _ -> space
  SegHist _ space _ _ _ -> space
-- | The kernel body of a segmented operation.
segBody :: SegOp lvl rep -> KernelBody rep
segBody segop = case segop of
  SegMap _ _ _ body -> body
  SegRed _ _ _ _ body -> body
  SegScan _ _ _ _ body -> body
  SegHist _ _ _ _ body -> body
-- | The type of the array produced for one kernel result, given the
-- operation's space and the type @t@ declared for that result.
--
-- * 'WriteReturns': @t@ with the destination shape.
-- * 'Returns': @t@ wrapped in one array dimension per space dimension.
--   The first space dimension becomes the outermost.
-- * 'TileReturns' / 'RegTileReturns': @t@ with the full (untiled)
--   dimension sizes.
segResultShape :: SegSpace -> Type -> KernelResult -> Type
segResultShape _ t (WriteReturns _ shape _ _) =
  t `arrayOfShape` shape
segResultShape space t Returns {} =
  foldr (flip arrayOfRow) t $ segSpaceDims space
segResultShape _ t (TileReturns _ dims _) =
  t `arrayOfShape` Shape (map fst dims)
segResultShape _ t (RegTileReturns _ dims_n_tiles _) =
  t `arrayOfShape` Shape (map (\(dim, _, _) -> dim) dims_n_tiles)
-- | The types of the values produced by a segmented operation.
--
-- * 'SegMap': one array per kernel result, shaped by 'segResultShape'.
-- * 'SegRed': first the operator results, with shape "all but the last
--   space dimension" followed by the operator shape; then the remaining
--   kernel results, shaped by 'segResultShape'.
-- * 'SegScan': like 'SegRed', but the operator results span every space
--   dimension.
-- * 'SegHist': one array per operator result, with shape "all but the
--   last space dimension", then the histogram shape, then the operator
--   shape.
--
-- 'SegRed' and 'SegHist' use 'init', so their space must have at least
-- one dimension.
segOpType :: SegOp lvl rep -> [Type]
segOpType (SegMap _ space ts kbody) =
  zipWith (segResultShape space) ts $ kernelBodyResult kbody
segOpType (SegRed _ space reds ts kbody) =
  red_ts
    ++ zipWith
      (segResultShape space)
      map_ts
      (drop (length red_ts) $ kernelBodyResult kbody)
  where
    -- The leading types of 'ts' belong to the reduction operands.
    map_ts = drop (length red_ts) ts
    segment_dims = init $ segSpaceDims space
    red_ts = do
      op <- reds
      let shape = Shape segment_dims <> segBinOpShape op
      map (`arrayOfShape` shape) (lambdaReturnType $ segBinOpLambda op)
segOpType (SegScan _ space scans ts kbody) =
  scan_ts
    ++ zipWith
      (segResultShape space)
      map_ts
      (drop (length scan_ts) $ kernelBodyResult kbody)
  where
    map_ts = drop (length scan_ts) ts
    scan_ts = do
      op <- scans
      let shape = Shape (segSpaceDims space) <> segBinOpShape op
      map (`arrayOfShape` shape) (lambdaReturnType $ segBinOpLambda op)
segOpType (SegHist _ space ops _ _) = do
  op <- ops
  let shape = Shape segment_dims <> histShape op <> histOpShape op
  map (`arrayOfShape` shape) (lambdaReturnType $ histOp op)
  where
    segment_dims = init $ segSpaceDims space
-- | The result types of a segmented operation are always known
-- statically.  They are 'segOpType' lifted to 'ExtType'.
instance TypedOp (SegOp lvl rep) where
  opType = pure . staticShapes . segOpType
-- | Aliasing information for segmented operations.
instance
  (ASTRep rep, Aliased rep, ASTConstraints lvl) =>
  AliasedOp (SegOp lvl rep)
  where
  -- No result aliases anything: one empty name set per result type.
  opAliases = map (const mempty) . segOpType

  -- Everything consumed inside the kernel body is consumed by the
  -- operation.  A 'SegHist' additionally consumes its destination arrays.
  consumedInOp (SegMap _ _ _ kbody) =
    consumedInKernelBody kbody
  consumedInOp (SegRed _ _ _ _ kbody) =
    consumedInKernelBody kbody
  consumedInOp (SegScan _ _ _ _ kbody) =
    consumedInKernelBody kbody
  consumedInOp (SegHist _ _ ops _ kbody) =
    namesFromList (concatMap histDest ops) <> consumedInKernelBody kbody
-- | Type-check a segmented operation.  The supplied function checks the
-- level.  Maps, reductions and scans are delegated to 'checkScanRed'.
-- Histograms are checked here.
typeCheckSegOp ::
  (TC.Checkable rep) =>
  (lvl -> TC.TypeM rep ()) ->
  SegOp lvl (Aliases rep) ->
  TC.TypeM rep ()
typeCheckSegOp checkLvl (SegMap lvl space ts kbody) = do
  checkLvl lvl
  -- A map is checked like a scan/reduction without operators.
  checkScanRed space [] ts kbody
typeCheckSegOp checkLvl (SegRed lvl space reds ts body) = do
  checkLvl lvl
  checkScanRed space reds' ts body
  where
    reds' =
      zip3
        (map segBinOpLambda reds)
        (map segBinOpNeutral reds)
        (map segBinOpShape reds)
typeCheckSegOp checkLvl (SegScan lvl space scans ts body) = do
  checkLvl lvl
  checkScanRed space scans' ts body
  where
    scans' =
      zip3
        (map segBinOpLambda scans)
        (map segBinOpNeutral scans)
        (map segBinOpShape scans)
typeCheckSegOp checkLvl (SegHist lvl space ops ts kbody) = do
  checkLvl lvl
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    nes_ts <- forM ops $ \(HistOp dest_shape rf dests nes shape op) -> do
      mapM_ (TC.require [Prim int64]) dest_shape
      TC.require [Prim int64] rf
      nes' <- mapM TC.checkArg nes
      mapM_ (TC.require [Prim int64]) $ shapeDims shape

      -- Check the operator against two copies of the neutral elements
      -- (with the operator shape's rank stripped).  Then require its
      -- return type to equal the neutral elements' types.
      let stripVecDims = stripArray $ shapeRank shape
      TC.checkLambda op $ map (TC.noArgAliases . first stripVecDims) $ nes' ++ nes'
      let nes_t = map TC.argType nes'
      unless (nes_t == lambdaReturnType op) $
        TC.bad $
          TC.TypeError $
            "SegHist operator has return type "
              <> prettyTuple (lambdaReturnType op)
              <> " but neutral element has type "
              <> prettyTuple nes_t

      -- Each neutral element needs exactly one destination array.  The
      -- 'zip' below would otherwise silently ignore extra destinations
      -- (which are still consumed) or missing ones.
      unless (length dests == length nes_t) $
        TC.bad $
          TC.TypeError $
            "SegHist operator has "
              <> prettyText (length nes_t)
              <> " neutral elements but "
              <> prettyText (length dests)
              <> " destination arrays."

      -- Destination arrays must have the proper type and are consumed.
      let dest_shape' = Shape segment_dims <> dest_shape <> shape
      forM_ (zip nes_t dests) $ \(t, dest) -> do
        TC.requireI [t `arrayOfShape` dest_shape'] dest
        TC.consume =<< TC.lookupAliases dest

      pure $ map (`arrayOfShape` shape) nes_t

    checkKernelBody ts kbody

    -- The body must return one i64 per histogram dimension of each
    -- operator, followed by the values to combine.
    let bucket_ret_t =
          concatMap ((`replicate` Prim int64) . shapeRank . histShape) ops
            ++ concat nes_ts
    unless (bucket_ret_t == ts) $
      TC.bad $
        TC.TypeError $
          "SegHist body has return type "
            <> prettyTuple ts
            <> " but should have type "
            <> prettyTuple bucket_ret_t
  where
    -- All but the last space dimension; 'init' assumes a non-empty space.
    segment_dims = init $ segSpaceDims space
-- | Type-check the components shared by 'SegRed' and 'SegScan'.
--
-- Each operator is given as a triple of its lambda, its neutral elements and
-- the shape it is vectorised over.  The checks performed are:
--
-- * the segment space is well-formed and every result type is valid;
-- * every dimension of each operator shape is an @i64@;
-- * each lambda accepts two copies of the neutral-element types and returns
--   exactly the neutral-element types;
-- * the leading results of the kernel body have the neutral-element types
--   lifted to arrays of the operator shape;
-- * the kernel body itself checks against @ts@.
--
-- The operator and body checks run with the segment space in scope.
checkScanRed ::
  (TC.Checkable rep) =>
  SegSpace ->
  [(Lambda (Aliases rep), [SubExp], Shape)] ->
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkScanRed space ops ts kbody = do
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    ne_ts <- forM ops $ \(lam, nes, shape) -> do
      -- The vectorisation shape must consist of i64 sizes.
      mapM_ (TC.require [Prim int64]) $ shapeDims shape
      nes' <- mapM TC.checkArg nes

      -- The operator is applied to two values of the neutral-element types.
      TC.checkLambda lam $ map TC.noArgAliases $ nes' ++ nes'
      let nes_t = map TC.argType nes'

      -- Report both types on mismatch, as the SegHist checker does.
      unless (lambdaReturnType lam == nes_t) $
        TC.bad $
          TC.TypeError $
            "wrong type for operator or neutral elements: operator has return type "
              <> prettyTuple (lambdaReturnType lam)
              <> " but neutral elements have type "
              <> prettyTuple nes_t

      pure $ map (`arrayOfShape` shape) nes_t

    -- The body's leading results must correspond to the operators' results.
    let expecting = concat ne_ts
        got = take (length expecting) ts
    unless (expecting == got) $
      TC.bad $
        TC.TypeError $
          "Wrong return for body (does not match neutral elements; expected "
            <> prettyText expecting
            <> "; found "
            <> prettyText got
            <> ")"

    checkKernelBody ts kbody
-- | A record of monadic transformations, one for each kind of component that
-- 'mapSegOpM' visits inside a 'SegOp'.  The @frep@ and @trep@ parameters are
-- the representations of lambdas and kernel bodies before and after the
-- transformation.
--
-- Field order matters: some instances construct this record positionally.
data SegOpMapper lvl frep trep m = SegOpMapper
  { -- | Transform a subexpression.
    mapOnSegOpSubExp :: SubExp -> m SubExp,
    -- | Transform an operator lambda (of a reduction, scan or histogram).
    mapOnSegOpLambda :: Lambda frep -> m (Lambda trep),
    -- | Transform the kernel body.
    mapOnSegOpBody :: KernelBody frep -> m (KernelBody trep),
    -- | Transform a variable name.
    mapOnSegOpVName :: VName -> m VName,
    -- | Transform the level annotation.
    mapOnSegOpLevel :: lvl -> m lvl
  }
-- | The do-nothing mapper: every component is returned unchanged via 'pure'.
-- Handy as a base for record updates that override only a few fields.
identitySegOpMapper :: (Monad m) => SegOpMapper lvl rep rep m
identitySegOpMapper = SegOpMapper pure pure pure pure pure
-- | Apply a mapper to a 'SegSpace'.  The physical index name is transformed
-- first, then each (index name, dimension size) pair in order, name before size.
mapOnSegSpace ::
  (Monad f) => SegOpMapper lvl frep trep f -> SegSpace -> f SegSpace
mapOnSegSpace tv (SegSpace phys dims) = do
  phys' <- mapOnSegOpVName tv phys
  dims' <- forM dims $ \(i, d) ->
    (,) <$> mapOnSegOpVName tv i <*> mapOnSegOpSubExp tv d
  pure $ SegSpace phys' dims'
-- | Apply a mapper to a 'SegBinOp'.  The commutativity flag is kept as-is.
-- The operator lambda is transformed first, then the neutral elements, then
-- the dimensions of the vectorisation shape.
mapSegBinOp ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  SegBinOp frep ->
  m (SegBinOp trep)
mapSegBinOp tv (SegBinOp comm red_op nes shape) = do
  red_op' <- mapOnSegOpLambda tv red_op
  nes' <- traverse onSubExp nes
  shape' <- traverse onSubExp shape
  pure $ SegBinOp comm red_op' nes' shape'
  where
    onSubExp = mapOnSegOpSubExp tv
-- | Apply a 'SegOpMapper' to every component of a 'SegOp'.
--
-- For each constructor the effects run in this order:
--
-- 1. the level,
-- 2. the segment space,
-- 3. the operators (if any),
-- 4. the result types,
-- 5. the kernel body.
mapSegOpM ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  SegOp lvl frep ->
  m (SegOp lvl trep)
mapSegOpM tv segop =
  case segop of
    SegMap lvl space ts body ->
      SegMap
        <$> onLevel lvl
        <*> onSpace space
        <*> traverse (mapOnSegOpType tv) ts
        <*> onBody body
    SegRed lvl space reds ts body ->
      SegRed
        <$> onLevel lvl
        <*> onSpace space
        <*> traverse (mapSegBinOp tv) reds
        <*> traverse onResultType ts
        <*> onBody body
    SegScan lvl space scans ts body ->
      SegScan
        <$> onLevel lvl
        <*> onSpace space
        <*> traverse (mapSegBinOp tv) scans
        <*> traverse onResultType ts
        <*> onBody body
    SegHist lvl space ops ts body ->
      SegHist
        <$> onLevel lvl
        <*> onSpace space
        <*> traverse onHistOp ops
        <*> traverse onResultType ts
        <*> onBody body
  where
    onLevel = mapOnSegOpLevel tv
    onSpace = mapOnSegSpace tv
    onBody = mapOnSegOpBody tv
    onSubExp = mapOnSegOpSubExp tv

    -- NOTE(review): only the 'SegMap' case uses 'mapOnSegOpType', which also
    -- maps accumulator names through 'mapOnSegOpVName'.  The other
    -- constructors hand just the subexpression mapper to 'mapOnType'.
    -- Preserved as-is.
    onResultType = mapOnType onSubExp

    -- Components are visited in field order: destination shape, race factor,
    -- destination arrays, neutral elements, operator shape, operator lambda.
    onHistOp (HistOp w rf arrs nes shape op) =
      HistOp
        <$> traverse onSubExp w
        <*> onSubExp rf
        <*> traverse (mapOnSegOpVName tv) arrs
        <*> traverse onSubExp nes
        <*> (Shape <$> traverse onSubExp (shapeDims shape))
        <*> mapOnSegOpLambda tv op
-- | Apply a mapper to the names and sizes occurring in a type.
--
-- * Primitive and memory types are returned unchanged.
-- * For an array type, the dimension sizes are mapped.
-- * For an accumulator type, these are mapped in order: the accumulator name
--   (via 'mapOnSegOpVName'), then its index space, then the shapes of its
--   element types.
mapOnSegOpType ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  Type ->
  m Type
mapOnSegOpType tv t =
  case t of
    Prim {} -> pure t
    Mem s -> pure $ Mem s
    Array et shape u ->
      Array et <$> traverse onSubExp shape <*> pure u
    Acc acc ispace ts u ->
      Acc
        <$> mapOnSegOpVName tv acc
        <*> traverse onSubExp ispace
        <*> traverse (bitraverse (traverse onSubExp) pure) ts
        <*> pure u
  where
    onSubExp = mapOnSegOpSubExp tv
-- | Apply a statement transformer to the statements nested in a 'SegOp'.
--
-- * The kernel body's statements are given the segment-space scope.
-- * Statements inside operator lambdas are given the segment-space scope
--   combined with the scope supplied by 'traverseLambdaStms'.
--
-- Everything else is left untouched.
traverseSegOpStms :: (Monad m) => OpStmsTraverser m (SegOp lvl rep) rep
traverseSegOpStms f segop =
  mapSegOpM
    ( identitySegOpMapper
        { mapOnSegOpLambda = traverseLambdaStms $ f . (seg_scope <>),
          mapOnSegOpBody = \(KernelBody dec stms res) ->
            (\stms' -> KernelBody dec stms' res) <$> f seg_scope stms
        }
    )
    segop
  where
    seg_scope = scopeOfSegSpace $ segSpace segop
-- | Name substitution applies the same substitution to every component a
-- 'SegOpMapper' can reach: subexpressions, lambdas, the kernel body, variable
-- names and the level.
instance
  (ASTRep rep, Substitute lvl) =>
  Substitute (SegOp lvl rep)
  where
  substituteNames subst op = runIdentity $ mapSegOpM substituter op
    where
      substituter =
        SegOpMapper
          { mapOnSegOpSubExp = Identity . substituteNames subst,
            mapOnSegOpLambda = Identity . substituteNames subst,
            mapOnSegOpBody = Identity . substituteNames subst,
            mapOnSegOpVName = Identity . substituteNames subst,
            mapOnSegOpLevel = Identity . substituteNames subst
          }
-- | Renaming treats the names bound by the segment space as binders for the
-- whole operation, and renames every component a 'SegOpMapper' reaches.
instance (ASTRep rep, ASTConstraints lvl) => Rename (SegOp lvl rep) where
  rename op = renameBound bound $ mapSegOpM renamer op
    where
      bound = M.keys $ scopeOfSegSpace $ segSpace op
      renamer =
        SegOpMapper
          { mapOnSegOpSubExp = rename,
            mapOnSegOpLambda = rename,
            mapOnSegOpBody = rename,
            mapOnSegOpVName = rename,
            mapOnSegOpLevel = rename
          }
-- | The free variables of a 'SegOp' are collected from every component a
-- 'SegOpMapper' visits.  The names bound by the segment space are then removed
-- with 'fvBind'.
instance (ASTRep rep, FreeIn lvl) => FreeIn (SegOp lvl rep) where
  freeIn' e = fvBind bound $ execState (mapSegOpM collector e) mempty
    where
      bound = namesFromList $ M.keys $ scopeOfSegSpace $ segSpace e

      -- Add the free variables of a component to the accumulated state and
      -- return the component unchanged.
      note f x = x <$ modify (<> f x)

      collector =
        SegOpMapper
          { mapOnSegOpSubExp = note freeIn',
            mapOnSegOpLambda = note freeIn',
            mapOnSegOpBody = note freeIn',
            mapOnSegOpVName = note freeIn',
            mapOnSegOpLevel = note freeIn'
          }
-- | Metrics for a 'SegOp' are recorded inside a section named after its
-- constructor.  The operator lambdas (none for 'SegMap') are recorded first,
-- then the kernel body.
instance (OpMetrics (Op rep)) => OpMetrics (SegOp lvl rep) where
  opMetrics segop =
    inside label $ do
      traverse_ lambdaMetrics lams
      kernelBodyMetrics kbody
    where
      (label, lams, kbody) =
        case segop of
          SegMap _ _ _ body -> ("SegMap", [], body)
          SegRed _ _ reds _ body -> ("SegRed", map segBinOpLambda reds, body)
          SegScan _ _ scans _ body -> ("SegScan", map segBinOpLambda scans, body)
          SegHist _ _ ops _ body -> ("SegHist", map histOp ops, body)
-- | Each dimension is rendered as @i < d@; the physical index follows in
-- parentheses, prefixed with @~@.
instance Pretty SegSpace where
  pretty (SegSpace phys dims) =
    apply [pretty i <+> "<" <+> pretty d | (i, d) <- dims]
      <+> parens ("~" <> pretty phys)
-- | A 'SegBinOp' is printed as the braced, comma-separated neutral
-- elements, then the operator shape, then the lambda. The lambda is
-- prefixed with @commutative @ when the operator is marked 'Commutative'.
instance PrettyRep rep => Pretty (SegBinOp rep) where
  pretty (SegBinOp comm lam nes shape) =
    PP.braces (PP.commasep (pretty <$> nes))
      <> PP.comma
      </> pretty shape
      <> PP.comma
      </> commDoc comm
      <> pretty lam
    where
      -- Only commutative operators get a keyword. The trailing space
      -- separates it from the lambda that follows.
      commDoc Commutative = "commutative "
      commDoc Noncommutative = mempty
-- | Pretty-print a 'SegOp'.
--
-- Each variant is printed in four parts:
--
-- * its keyword fused to the level;
-- * the aligned 'SegSpace';
-- * a colon and the tuple of result types;
-- * the kernel body in braces.
--
-- 'SegRed', 'SegScan' and 'SegHist' also print their operators between
-- the space and the result types. The operators are parenthesised and
-- separated by a comma plus a line break.
instance (PrettyRep rep, PP.Pretty lvl) => PP.Pretty (SegOp lvl rep) where
  pretty segop = case segop of
    SegMap lvl space ts body ->
      heading "segmap" lvl space
        <+> PP.colon
        <+> resultTypes ts
        <+> kernelBlock body
    SegRed lvl space reds ts body ->
      heading "segred" lvl space
        </> operatorList (map pretty reds)
        </> PP.colon
        <+> resultTypes ts
        <+> kernelBlock body
    SegScan lvl space scans ts body ->
      heading "segscan" lvl space
        </> operatorList (map pretty scans)
        </> PP.colon
        <+> resultTypes ts
        <+> kernelBlock body
    SegHist lvl space ops ts body ->
      heading "seghist" lvl space
        </> operatorList (map ppOp ops)
        </> PP.colon
        <+> resultTypes ts
        <+> kernelBlock body
    where
      -- The keyword and level are concatenated with no separator.
      -- The space follows after a line break and is aligned.
      heading keyword level segspace =
        keyword <> pretty level </> PP.align (pretty segspace)

      operatorList docs =
        PP.parens $ mconcat $ intersperse (PP.comma <> PP.line) docs

      resultTypes types = ppTuple' $ map pretty types

      kernelBlock kbody = PP.nestedBlock "{" "}" $ pretty kbody

      -- A histogram operator is printed as its width, race factor, braced
      -- destinations, braced neutral elements, shape and lambda.
      ppOp (HistOp w rf dests nes shape op) =
        pretty w
          <> PP.comma
          <+> pretty rf
          <> PP.comma
          </> PP.braces (PP.commasep (map pretty dests))
          <> PP.comma
          </> PP.braces (PP.commasep (map pretty nes))
          <> PP.comma
          </> pretty shape
          <> PP.comma
          </> pretty op
-- | Alias annotation for 'SegOp'.
--
-- Both directions traverse the operation with 'mapSegOpM' in the
-- 'Identity' monad. Sub-expressions, variable names and the level pass
-- through unchanged; only lambdas and kernel bodies are converted.
instance
  ( ASTRep rep,
    ASTRep (Aliases rep),
    CanBeAliased (Op rep),
    ASTConstraints lvl
  ) =>
  CanBeAliased (SegOp lvl rep)
  where
  type OpWithAliases (SegOp lvl rep) = SegOp lvl (Aliases rep)

  -- Annotate lambdas and kernel bodies using the given alias table.
  addOpAliases aliases op = runIdentity $ mapSegOpM withAliases op
    where
      withAliases =
        SegOpMapper
          Identity
          (Identity . Alias.analyseLambda aliases)
          (Identity . aliasAnalyseKernelBody aliases)
          Identity
          Identity

  -- Strip the alias annotations from lambdas and kernel bodies.
  removeOpAliases op = runIdentity $ mapSegOpM withoutAliases op
    where
      withoutAliases =
        SegOpMapper
          Identity
          (Identity . removeLambdaAliases)
          (Identity . removeKernelBodyAliases)
          Identity
          Identity
-- | Add wisdom to a kernel body.
--
-- The statements are informed with 'informStms', and the body is rebuilt
-- with 'mkWiseKernelBody'. The original body decoration and the kernel
-- results are passed along unchanged.
informKernelBody :: Informing rep => KernelBody rep -> KernelBody (Wise rep)
informKernelBody (KernelBody bdec kstms kres) =
  mkWiseKernelBody bdec wiseStms kres
  where
    wiseStms = informStms kstms
-- | Wisdom annotation for 'SegOp'.
--
-- Both directions traverse the operation with 'mapSegOpM' in the
-- 'Identity' monad. Sub-expressions, variable names and the level pass
-- through unchanged; only lambdas and kernel bodies are converted.
instance
  (CanBeWise (Op rep), ASTRep rep, ASTConstraints lvl) =>
  CanBeWise (SegOp lvl rep)
  where
  type OpWithWisdom (SegOp lvl rep) = SegOp lvl (Wise rep)

  -- Drop wisdom from lambdas and kernel bodies.
  removeOpWisdom op = runIdentity $ mapSegOpM stripWisdom op
    where
      stripWisdom =
        SegOpMapper
          Identity
          (Identity . removeLambdaWisdom)
          (Identity . removeKernelBodyWisdom)
          Identity
          Identity

  -- Inform lambdas and kernel bodies, producing wise versions.
  addOpWisdom op = runIdentity $ mapSegOpM addWisdom op
    where
      addWisdom =
        SegOpMapper
          Identity
          (Identity . informLambda)
          (Identity . informKernelBody)
          Identity
          Identity
-- | Symbol-table indexing support for 'SegOp'.
--
-- Only 'SegMap' is handled; every other variant yields 'Nothing'.
-- Indexing result @k@ of a 'SegMap' at indices @is@ succeeds only when
-- all of the following hold:
--
-- * result @k@ is a 'Returns' marked 'ResultMaySimplify';
-- * there are at least as many indices as thread indices;
-- * the returned value is a variable present in the index table.
--
-- The index table is built by binding each thread index to the
-- corresponding element of @is@. It is then extended by walking the
-- kernel statements in order. Each single-name statement expressible as
-- a 'PrimExp' over the table is added as 'ST.Indexed'. Each 'Index' of an
-- array that is available in @vtable@ is added as 'ST.IndexedArray',
-- provided its slice dimensions match the surplus ("excess") indices.
instance ASTRep rep => ST.IndexOp (SegOp lvl rep) where
  indexOp vtable k (SegMap _ space _ kbody) is = do
    Returns ResultMaySimplify _ se <- maybeNth k $ kernelBodyResult kbody
    guard $ length gtids <= length is
    -- A non-variable result fails the pattern, giving Nothing.
    Var v <- pure se
    M.lookup v expanded
    where
      (gtids, _) = unzip $ unSegSpace space
      excess_is = drop (length gtids) is

      -- Thread indices map to the outer indices, with no certificates.
      initial = M.fromList $ zip gtids $ map (ST.Indexed mempty . untyped) is
      -- Grow the table by walking the kernel statements in order.
      expanded = foldl' expandIndexedTable initial $ kernelBodyStms kbody

      expandIndexedTable table stm =
        case patNames $ stmPat stm of
          [v]
            | Just (pe, cs) <-
                runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm ->
                M.insert v (ST.Indexed (stmCerts stm <> cs) pe) table
            | BasicOp (Index arr slice) <- stmExp stm,
              length (sliceDims slice) == length excess_is,
              arr `ST.available` vtable,
              Just (slice', cs) <- asPrimExpSlice table slice ->
                let fixed = fixSlice (isInt64 <$> slice') excess_is
                 in M.insert v (ST.IndexedArray (stmCerts stm <> cs) arr fixed) table
          -- Anything else leaves the table unchanged.
          _ -> table

      asPrimExpSlice table slice =
        runWriterT $ traverse (primExpFromSubExpM $ asPrimExp table) slice

      -- Resolve a variable through the table (recording its certificates),
      -- or as a leaf if the outer symbol table says it has a primitive type.
      asPrimExp table v =
        case M.lookup v table of
          Just (ST.Indexed cs e) -> tell cs >> pure e
          _
            | Just (Prim pt) <- ST.lookupType v vtable -> pure $ LeafExp v pt
            | otherwise -> lift Nothing
  indexOp _ _ _ _ = Nothing
-- | Every 'SegOp', regardless of variant, reports 'False' from 'cheapOp'
-- and 'True' from 'safeOp'.
instance
  (ASTRep rep, ASTConstraints lvl) =>
  IsOp (SegOp lvl rep)
  where
  cheapOp = const False
  safeOp = const True
-- | Simplify a segment space by simplifying each dimension bound in order.
-- The physical index name and the thread index names are kept as-is.
instance Engine.Simplifiable SegSpace where
  simplify (SegSpace phys dims) = do
    dims' <- mapM simplifyDim dims
    pure $ SegSpace phys dims'
    where
      simplifyDim (gtid, bound) = do
        bound' <- Engine.simplify bound
        pure (gtid, bound')
-- | Simplify the components of a 'KernelResult', left to right.
-- The 'ResultManifest' of a 'Returns' is kept unchanged.
instance Engine.Simplifiable KernelResult where
  simplify kres = case kres of
    Returns manifest cs what -> do
      cs' <- Engine.simplify cs
      what' <- Engine.simplify what
      pure $ Returns manifest cs' what'
    WriteReturns cs ws a res -> do
      cs' <- Engine.simplify cs
      ws' <- Engine.simplify ws
      a' <- Engine.simplify a
      res' <- Engine.simplify res
      pure $ WriteReturns cs' ws' a' res'
    TileReturns cs dims what -> do
      cs' <- Engine.simplify cs
      dims' <- Engine.simplify dims
      what' <- Engine.simplify what
      pure $ TileReturns cs' dims' what'
    RegTileReturns cs dims_n_tiles what -> do
      cs' <- Engine.simplify cs
      dims_n_tiles' <- Engine.simplify dims_n_tiles
      what' <- Engine.simplify what
      pure $ RegTileReturns cs' dims_n_tiles' what'
-- | Build a wise 'KernelBody'.
--
-- The wise body decoration is obtained from 'mkWiseBody', using the
-- 'SubExp's of the kernel results as the body result. Only that
-- decoration is kept; the given statements and kernel results are reused
-- directly.
mkWiseKernelBody ::
  (ASTRep rep, CanBeWise (Op rep)) =>
  BodyDec rep ->
  Stms (Wise rep) ->
  [KernelResult] ->
  KernelBody (Wise rep)
mkWiseKernelBody dec stms res = KernelBody wiseDec stms res
  where
    resultSubExps = map kernelResultSubExp res
    Body wiseDec _ _ = mkWiseBody dec stms (subExpsRes resultSubExps)
-- | Monadically build a 'KernelBody'.
--
-- The body decoration comes from 'mkBodyM', using the 'SubExp's of the
-- kernel results as the body result. Only that decoration is kept; the
-- given statements and kernel results are reused directly.
mkKernelBodyM ::
  MonadBuilder m =>
  Stms (Rep m) ->
  [KernelResult] ->
  m (KernelBody (Rep m))
mkKernelBodyM stms kres =
  mkBodyM stms (subExpsRes resultSubExps) >>= \(Body dec' _ _) ->
    pure $ KernelBody dec' stms kres
  where
    resultSubExps = map kernelResultSubExp kres
-- | Simplify the 'KernelBody' of a 'SegOp' whose iteration space is
-- @space@, returning the new body together with the statements that
-- were hoisted out of it.
--
-- Hoisting of a statement is blocked when any of these predicates
-- holds for it: 'Engine.hasFree' on the names bound by @space@,
-- 'Engine.isOp', the environment's 'Engine.blockHoistPar' blocker,
-- 'Engine.isConsumed', 'Engine.isConsuming', or
-- 'Engine.isDeviceMigrated'.
--
-- While simplifying, the symbol table is extended with the segment
-- space, has 'ST.simplifyMemory' enabled, and marks every array written
-- by a 'WriteReturns' result as consumed.
simplifyKernelBody ::
  (Engine.SimplifiableRep rep, BodyDec rep ~ ()) =>
  SegSpace ->
  KernelBody (Wise rep) ->
  Engine.SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
simplifyKernelBody space (KernelBody _ stms res) = do
  par_blocker <-
    Engine.asksEngineEnv (Engine.blockHoistPar . Engine.envHoistBlockers)

  let blocker =
        Engine.hasFree bound_here
          `Engine.orIf` Engine.isOp
          `Engine.orIf` par_blocker
          `Engine.orIf` Engine.isConsumed
          `Engine.orIf` Engine.isConsuming
          `Engine.orIf` Engine.isDeviceMigrated

      -- Names bound by the body statements; passed to
      -- 'ST.hideCertified' while the results are simplified.
      bound_in_body = namesFromList (M.keys (scopeOf stms))

      simplifyResults = do
        res' <-
          Engine.localVtable (ST.hideCertified bound_in_body) $
            mapM Engine.simplify res
        pure (res', UT.usages (freeIn res'))

  (body_res, body_stms, hoisted) <-
    Engine.localVtable markResultConsumption $
      Engine.localVtable (<> scope_vtable) $
        Engine.localVtable enableMemory $
          Engine.enterLoop $
            Engine.blockIf blocker stms simplifyResults

  pure (mkWiseKernelBody () body_stms body_res, hoisted)
  where
    scope_vtable = segSpaceSymbolTable space
    bound_here = namesFromList (M.keys (scopeOfSegSpace space))

    enableMemory vtable = vtable {ST.simplifyMemory = True}

    -- Ensure we do not try to use anything that is consumed in the
    -- result.
    markResultConsumption vtable =
      foldl' (flip ST.consume) vtable (concatMap consumedInResult res)

    consumedInResult :: KernelResult -> [VName]
    consumedInResult (WriteReturns _ _ arr _) = [arr]
    consumedInResult _ = []
-- | Simplify a lambda with 'Engine.simplifyLambda', running it under
-- 'Engine.blockMigrated'.  Returns the new lambda and the statements
-- hoisted out of it.
simplifyLambda ::
  (Engine.SimplifiableRep rep) =>
  Lambda (Wise rep) ->
  Engine.SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambda lam = Engine.blockMigrated (Engine.simplifyLambda lam)
-- | Build a symbol table describing a segment space.
--
-- The flat index is bound as an 'Int64' 'IndexName'.  Each
-- @(gtid, dim)@ pair is then added, left to right, with
-- 'ST.insertLoopVar' as an 'Int64' loop variable, passing @dim@ as
-- its bound argument.
segSpaceSymbolTable :: (ASTRep rep) => SegSpace -> ST.SymbolTable rep
segSpaceSymbolTable (SegSpace flat gtids_and_dims) =
  foldl' addDimension flat_vtable gtids_and_dims
  where
    flat_vtable = ST.fromScope $ M.singleton flat $ IndexName Int64
    addDimension vtable (gtid, dim) = ST.insertLoopVar gtid Int64 dim vtable
-- | Simplify a 'SegBinOp': its lambda (with 'ST.simplifyMemory'
-- enabled), then its shape, then its neutral elements.  The
-- commutativity flag is kept as is.  Statements hoisted out of the
-- lambda are returned alongside the new operator.
simplifySegBinOp ::
  (Engine.SimplifiableRep rep) =>
  SegBinOp (Wise rep) ->
  Engine.SimpleM rep (SegBinOp (Wise rep), Stms (Wise rep))
simplifySegBinOp (SegBinOp comm lam nes shape) = do
  (lam', hoisted) <- Engine.localVtable enableMemory (simplifyLambda lam)
  (shape', nes') <- (,) <$> Engine.simplify shape <*> mapM Engine.simplify nes
  pure (SegBinOp comm lam' nes' shape', hoisted)
  where
    enableMemory vtable = vtable {ST.simplifyMemory = True}
-- | Simplify a 'SegOp', returning the simplified operation and the
-- statements hoisted out of it.
--
-- In every case the level, space and result types are simplified and
-- the kernel body goes through 'simplifyKernelBody' (with the original
-- space).  The operator lambdas of 'SegRed', 'SegScan' and 'SegHist' are
-- simplified with the scope of the segment space added to the symbol
-- table.  Statements hoisted from the operators come before those
-- hoisted from the body.
simplifySegOp ::
  ( Engine.SimplifiableRep rep,
    BodyDec rep ~ (),
    Engine.Simplifiable lvl
  ) =>
  SegOp lvl (Wise rep) ->
  Engine.SimpleM rep (SegOp lvl (Wise rep), Stms (Wise rep))
simplifySegOp (SegMap lvl space ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure (SegMap lvl' space' ts' kbody', body_hoisted)
simplifySegOp (SegRed lvl space reds ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (reds', reds_hoisted) <-
    Engine.localVtable (<> space_vtable) $
      unzip <$> mapM simplifySegBinOp reds
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  let hoisted = mconcat reds_hoisted <> body_hoisted
  pure (SegRed lvl' space' reds' ts' kbody', hoisted)
  where
    space_vtable = ST.fromScope (scopeOfSegSpace space)
simplifySegOp (SegScan lvl space scans ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (scans', scans_hoisted) <-
    Engine.localVtable (<> space_vtable) $
      unzip <$> mapM simplifySegBinOp scans
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  let hoisted = mconcat scans_hoisted <> body_hoisted
  pure (SegScan lvl' space' scans' ts' kbody', hoisted)
  where
    space_vtable = ST.fromScope (scopeOfSegSpace space)
simplifySegOp (SegHist lvl space ops ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (ops', ops_hoisted) <- unzip <$> mapM simplifyHistOp ops
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  let hoisted = mconcat ops_hoisted <> body_hoisted
  pure (SegHist lvl' space' ops' ts' kbody', hoisted)
  where
    space_vtable = ST.fromScope (scopeOfSegSpace space)

    -- The scalar parts of the operator are simplified in the ambient
    -- symbol table; only the lambda sees the segment space.
    simplifyHistOp (HistOp w rf arrs nes dims op_lam) = do
      w' <- Engine.simplify w
      rf' <- Engine.simplify rf
      arrs' <- Engine.simplify arrs
      nes' <- Engine.simplify nes
      dims' <- Engine.simplify dims
      (op_lam', op_hoisted) <-
        Engine.localVtable (<> space_vtable) $
          Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True}) $
            simplifyLambda op_lam
      pure (HistOp w' rf' arrs' nes' dims' op_lam', op_hoisted)
-- | Representations whose 'Op' type can carry a 'SegOp'.
class HasSegOp rep where
  -- | The level annotation attached to this representation's 'SegOp's.
  type SegOpLevel rep

  -- | Wrap a 'SegOp' as an 'Op' of the representation.
  segOp :: SegOp (SegOpLevel rep) rep -> Op rep

  -- | View an 'Op' as a 'SegOp', returning 'Nothing' when it is not one.
  asSegOp :: Op rep -> Maybe (SegOp (SegOpLevel rep) rep)
-- | The simplification rule book for 'SegOp's: 'segOpRuleTopDown' as
-- the single top-down rule and 'segOpRuleBottomUp' as the single
-- bottom-up rule.
segOpRules ::
  (HasSegOp rep, BuilderOps rep, Buildable rep, Aliased rep) =>
  RuleBook rep
segOpRules =
  ruleBook
    [RuleOp segOpRuleTopDown]
    [RuleOp segOpRuleBottomUp]
-- | Top-down rule: delegate to 'topDownSegOp' when the 'Op' is a
-- 'SegOp' (as decided by 'asSegOp'), otherwise 'Skip'.
segOpRuleTopDown ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  TopDownRuleOp rep
segOpRuleTopDown top_down pat aux op =
  maybe Skip (topDownSegOp top_down pat aux) (asSegOp op)
-- | Bottom-up rule: delegate to 'bottomUpSegOp' when the 'Op' is a
-- 'SegOp' (as decided by 'asSegOp'), otherwise 'Skip'.
segOpRuleBottomUp ::
  (HasSegOp rep, BuilderOps rep, Aliased rep) =>
  BottomUpRuleOp rep
segOpRuleBottomUp bottom_up pat aux op =
  maybe Skip (bottomUpSegOp bottom_up pat aux) (asSegOp op)
topDownSegOp ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  ST.SymbolTable rep ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
-- A 'SegMap' result that is a plain 'Returns' with no certificates,
-- marked 'ResultMaySimplify', and whose value is a constant or a
-- variable already present in the outer symbol table is removed from
-- the 'SegMap'.  Its pattern element is instead bound to a 'Replicate'
-- of that value over the dimensions of the segment space.
topDownSegOp vtable (Pat kpes) dec (SegMap lvl space ts (KernelBody _ kstms kres)) =
  Simplify $ do
    kept <- filterM keepResult (zip3 ts kpes kres)
    let (ts', kpes', kres') = unzip3 kept

    -- Nothing was removed, so the rule does not apply.
    when (kres == kres') cannotSimplify

    kbody <- mkKernelBodyM kstms kres'
    addStm . Let (Pat kpes') dec . Op . segOp $ SegMap lvl space ts' kbody
  where
    isInvariant Constant {} = True
    isInvariant (Var v) = isJust (ST.lookup v vtable)

    -- Returns False (dropping the result) after emitting a replicate
    -- for an invariant result; True keeps the result in the SegMap.
    keepResult (_, pe, Returns rm cs se)
      | cs == mempty,
        rm == ResultMaySimplify,
        isInvariant se = do
          letBindNames [patElemName pe] . BasicOp $
            Replicate (Shape (segSpaceDims space)) se
          pure False
    keepResult _ = pure True
topDownSegOp SymbolTable rep
_ (Pat [PatElem (LetDec rep)]
pes) StmAux (ExpDec rep)
_ (SegRed SegOpLevel rep
lvl SegSpace
space [SegBinOp rep]
ops [Type]
ts KernelBody rep
kbody)
| forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegBinOp rep]
ops forall a. Ord a => a -> a -> Bool
> Int
1,
[[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings <-
forall a. (a -> a -> Bool) -> [a] -> [[a]]
groupBy forall {k} {k} {rep :: k} {b} {rep :: k} {b}.
(SegBinOp rep, b) -> (SegBinOp rep, b) -> Bool
sameShape forall a b. (a -> b) -> a -> b
$
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp rep]
ops forall a b. (a -> b) -> a -> b
$
forall a. [Int] -> [a] -> [[a]]
chunks (forall a b. (a -> b) -> [a] -> [b]
map (forall (t :: * -> *) a. Foldable t => t a -> Int
length forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall {k} (rep :: k). SegBinOp rep -> [SubExp]
segBinOpNeutral) [SegBinOp rep]
ops) forall a b. (a -> b) -> a -> b
$
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (LetDec rep)]
red_pes [Type]
red_ts [KernelResult]
red_res,
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any ((forall a. Ord a => a -> a -> Bool
> Int
1) forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall (t :: * -> *) a. Foldable t => t a -> Int
length) [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings = forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$ do
let ([SegBinOp rep]
ops', [[(PatElem (LetDec rep), Type, KernelResult)]]
aux) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe forall {k} {rep :: k} {a}.
Buildable rep =>
[(SegBinOp rep, [a])] -> Maybe (SegBinOp rep, [a])
combineOps [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings
([PatElem (LetDec rep)]
red_pes', [Type]
red_ts', [KernelResult]
red_res') = forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[(PatElem (LetDec rep), Type, KernelResult)]]
aux
pes' :: [PatElem (LetDec rep)]
pes' = [PatElem (LetDec rep)]
red_pes' forall a. [a] -> [a] -> [a]
++ [PatElem (LetDec rep)]
map_pes
ts' :: [Type]
ts' = [Type]
red_ts' forall a. [a] -> [a] -> [a]
++ [Type]
map_ts
kbody' :: KernelBody rep
kbody' = KernelBody rep
kbody {kernelBodyResult :: [KernelResult]
kernelBodyResult = [KernelResult]
red_res' forall a. [a] -> [a] -> [a]
++ [KernelResult]
map_res}
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind (forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
pes') forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
HasSegOp rep =>
SegOp (SegOpLevel rep) rep -> Op rep
segOp forall a b. (a -> b) -> a -> b
$ forall {k} lvl (rep :: k).
lvl
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp lvl rep
SegRed SegOpLevel rep
lvl SegSpace
space [SegBinOp rep]
ops' [Type]
ts' KernelBody rep
kbody'
where
([PatElem (LetDec rep)]
red_pes, [PatElem (LetDec rep)]
map_pes) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall {k} (rep :: k). [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) [PatElem (LetDec rep)]
pes
([Type]
red_ts, [Type]
map_ts) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall {k} (rep :: k). [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) [Type]
ts
([KernelResult]
red_res, [KernelResult]
map_res) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall {k} (rep :: k). [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody rep
kbody
sameShape :: (SegBinOp rep, b) -> (SegBinOp rep, b) -> Bool
sameShape (SegBinOp rep
op1, b
_) (SegBinOp rep
op2, b
_) = forall {k} (rep :: k). SegBinOp rep -> ShapeBase SubExp
segBinOpShape SegBinOp rep
op1 forall a. Eq a => a -> a -> Bool
== forall {k} (rep :: k). SegBinOp rep -> ShapeBase SubExp
segBinOpShape SegBinOp rep
op2
combineOps :: [(SegBinOp rep, [a])] -> Maybe (SegBinOp rep, [a])
combineOps [] = forall a. Maybe a
Nothing
combineOps ((SegBinOp rep, [a])
x : [(SegBinOp rep, [a])]
xs) = forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' forall {k} {rep :: k} {a}.
Buildable rep =>
(SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a])
combine (SegBinOp rep, [a])
x [(SegBinOp rep, [a])]
xs
combine :: (SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a])
combine (SegBinOp rep
op1, [a]
op1_aux) (SegBinOp rep
op2, [a]
op2_aux) =
let lam1 :: Lambda rep
lam1 = forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp rep
op1
lam2 :: Lambda rep
lam2 = forall {k} (rep :: k). SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp rep
op2
([Param (LParamInfo rep)]
op1_xparams, [Param (LParamInfo rep)]
op1_yparams) =
forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length (forall {k} (rep :: k). SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op1)) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam1
([Param (LParamInfo rep)]
op2_xparams, [Param (LParamInfo rep)]
op2_yparams) =
forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length (forall {k} (rep :: k). SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op2)) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam2
lam :: Lambda rep
lam =
Lambda
{ lambdaParams :: [Param (LParamInfo rep)]
lambdaParams =
[Param (LParamInfo rep)]
op1_xparams
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
op2_xparams
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
op1_yparams
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
op2_yparams,
lambdaReturnType :: [Type]
lambdaReturnType = forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType Lambda rep
lam1 forall a. [a] -> [a] -> [a]
++ forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType Lambda rep
lam2,
lambdaBody :: Body rep
lambdaBody =
forall {k} (rep :: k).
Buildable rep =>
Stms rep -> Result -> Body rep
mkBody (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms (forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda rep
lam1) forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k). Body rep -> Stms rep
bodyStms (forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda rep
lam2)) forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k). Body rep -> Result
bodyResult (forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda rep
lam1) forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k). Body rep -> Result
bodyResult (forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda rep
lam2)
}
in ( SegBinOp
{ segBinOpComm :: Commutativity
segBinOpComm = forall {k} (rep :: k). SegBinOp rep -> Commutativity
segBinOpComm SegBinOp rep
op1 forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k). SegBinOp rep -> Commutativity
segBinOpComm SegBinOp rep
op2,
segBinOpLambda :: Lambda rep
segBinOpLambda = Lambda rep
lam,
segBinOpNeutral :: [SubExp]
segBinOpNeutral = forall {k} (rep :: k). SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op1 forall a. [a] -> [a] -> [a]
++ forall {k} (rep :: k). SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op2,
segBinOpShape :: ShapeBase SubExp
segBinOpShape = forall {k} (rep :: k). SegBinOp rep -> ShapeBase SubExp
segBinOpShape SegBinOp rep
op1
},
[a]
op1_aux forall a. [a] -> [a] -> [a]
++ [a]
op2_aux
)
topDownSegOp SymbolTable rep
_ Pat (LetDec rep)
_ StmAux (ExpDec rep)
_ SegOp (SegOpLevel rep) rep
_ = forall {k} (rep :: k). Rule rep
Skip
-- | Decompose a 'SegOp' into the components that the generic
-- simplification rules operate on:
--
--   * the result types of the kernel body;
--
--   * the kernel body itself;
--
--   * the number of leading results that belong to the operator
--     (reduction/scan values, or histogram destination arrays) rather
--     than being plain map results -- zero for 'SegMap';
--
--   * a function that rebuilds the same kind of 'SegOp', keeping the
--     original level, space and operators, from new result types and a
--     new kernel body.
segOpGuts ::
  SegOp (SegOpLevel rep) rep ->
  ( [Type],
    KernelBody rep,
    Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
  )
segOpGuts segop =
  case segop of
    SegMap lvl space ts body ->
      (ts, body, 0, SegMap lvl space)
    SegScan lvl space ops ts body ->
      (ts, body, segBinOpResults ops, SegScan lvl space ops)
    SegRed lvl space ops ts body ->
      (ts, body, segBinOpResults ops, SegRed lvl space ops)
    SegHist lvl space hist_ops ts body ->
      (ts, body, numHistDests hist_ops, SegHist lvl space hist_ops)
  where
    -- Every destination array of every histogram operator is one
    -- operator result.
    numHistDests hist_ops = sum [length (histDest hop) | hop <- hist_ops]
-- | Bottom-up rule that moves trivial indexing results out of a 'SegOp'.
--
-- Consider a kernel statement @x = arr[gtid_1, ..., gtid_n, rest...]@ in
-- which the leading indices are exactly the thread indices of the
-- 'SegSpace', and every name in @arr@ and @rest@ is bound outside the
-- 'SegOp' (it is found in the symbol table).  If @x@ is returned directly
-- as a /map/ result of the 'SegOp' -- i.e. not one of the leading operator
-- results of a reduction, scan or histogram -- then that result is just
-- @arr[0:d_1, ..., 0:d_n, rest...]@, where the @d_i@ are the dimensions of
-- the space.  This rule removes such results from the 'SegOp' and binds
-- the corresponding pattern element to that slice outside of it.
--
-- If the hoisted pattern element is consumed afterwards, or @arr@ is
-- consumed inside the 'SegOp', the slice is first bound under a fresh
-- @_precopy@ name and then copied, so the result does not alias @arr@.
--
-- Fails with 'cannotSimplify' when no result could be hoisted.
bottomUpSegOp ::
  (Aliased rep, HasSegOp rep, BuilderOps rep) =>
  (ST.SymbolTable rep, UT.UsageTable) ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
bottomUpSegOp (vtable, used) (Pat kpes) dec segop = Simplify $ do
  -- Walk the kernel statements in order.  Each one is either hoisted
  -- (its result is dropped from the SegOp) or kept in the kernel body.
  (kpes', kts', kres', kstms') <-
    localScope (scopeOfSegSpace space) $
      foldM distribute (kpes, kts, kres, mempty) kstms

  -- No result was removed, so the rule does not apply.
  when (kpes' == kpes) cannotSimplify

  kbody' <-
    localScope (scopeOfSegSpace space) $
      mkKernelBodyM kstms' kres'

  addStm $ Let (Pat kpes') dec $ Op $ segOp $ mk_segop kts' kbody'
  where
    (kts, kbody@(KernelBody _ kstms kres), num_nonmap_results, mk_segop) =
      segOpGuts segop
    free_in_kstms = foldMap freeIn kstms
    consumed_in_segop = consumedInKernelBody kbody
    space = segSpace segop

    -- Recognise @arr[gtid_1, ..., gtid_n, rest...]@ where @arr@ and @rest@
    -- only mention names bound outside the SegOp; yields @(rest, arr)@.
    sliceWithGtidsFixed stm
      | Let _ _ (BasicOp (Index arr slice)) <- stm,
        space_slice <- map (DimFix . Var . fst) $ unSegSpace space,
        space_slice `isPrefixOf` unSlice slice,
        remaining_slice <- Slice $ drop (length space_slice) (unSlice slice),
        all (isJust . flip ST.lookup vtable) $
          namesToList $
            freeIn arr <> freeIn remaining_slice =
          Just (remaining_slice, arr)
      | otherwise =
          Nothing

    -- Fold step over the kernel statements.  Accumulates the pattern
    -- elements, types and results still produced by the SegOp, together
    -- with the statements that stay in the kernel body.
    distribute (kpes', kts', kres', kstms') stm
      | Let (Pat [pe]) _ _ <- stm,
        Just (Slice remaining_slice, arr) <- sliceWithGtidsFixed stm,
        Just (kpe, kpes'', kts'', kres'') <- isResult kpes' kts' kres' pe = do
          let outer_slice =
                map
                  ( \d ->
                      DimSlice
                        (constant (0 :: Int64))
                        d
                        (constant (1 :: Int64))
                  )
                  (segSpaceDims space)
              index kpe' =
                letBindNames [patElemName kpe'] . BasicOp . Index arr $
                  Slice (outer_slice <> remaining_slice)
          if patElemName kpe `UT.isConsumed` used || arr `nameIn` consumed_in_segop
            then do
              -- Copy, so the hoisted result does not alias 'arr'.
              precopy <- newVName $ baseString (patElemName kpe) <> "_precopy"
              index kpe {patElemName = precopy}
              letBindNames [patElemName kpe] $ BasicOp $ Copy precopy
            else index kpe
          -- The statement must stay in the kernel if anything left there
          -- still mentions 'pe': another statement, or one of the
          -- remaining results (e.g. the index or value of a
          -- 'WriteReturns').  Checking only the statements would leave
          -- such a result referring to an unbound name.
          let still_needed = free_in_kstms <> freeIn kres''
          pure
            ( kpes'',
              kts'',
              kres'',
              if patElemName pe `nameIn` still_needed
                then kstms' <> oneStm stm
                else kstms'
            )
    distribute (kpes', kts', kres', kstms') stm =
      pure (kpes', kts', kres', kstms' <> oneStm stm)

    -- Succeeds if 'pe' is returned by exactly one plain 'Returns' result,
    -- and that result is a map result (its position in the original
    -- pattern is at or after 'num_nonmap_results').  Returns the matching
    -- pattern element and the remaining pattern elements, types and
    -- results.
    isResult kpes' kts' kres' pe =
      case partition matches $ zip3 kpes' kts' kres' of
        ([(kpe, _, _)], kpes_and_kres)
          | Just i <- elemIndex kpe kpes,
            i >= num_nonmap_results,
            (kpes'', kts'', kres'') <- unzip3 kpes_and_kres ->
              Just (kpe, kpes'', kts'', kres'')
        _ -> Nothing
      where
        matches (_, _, Returns _ _ (Var v)) = v == patElemName pe
        matches _ = False
-- | Refine the returns of a kernel body's results.
--
-- Results are paired positionally with the given 'ExpReturns'; any
-- surplus on either side is dropped, as with 'zipWithM'.  A
-- 'WriteReturns' result gets the returns of its destination array,
-- obtained with 'varReturns'.  Every other result keeps the return it
-- was paired with.
kernelBodyReturns ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  KernelBody somerep ->
  [ExpReturns] ->
  m [ExpReturns]
kernelBodyReturns kbody rets =
  zipWithM refine (kernelBodyResult kbody) rets
  where
    refine res ret =
      case res of
        WriteReturns _ _ dest _ -> varReturns dest
        _ -> pure ret
-- | The memory-annotated returns of a 'SegOp'.
--
-- For 'SegMap', 'SegRed' and 'SegScan' these are derived from the
-- operator's type ('opType' followed by 'extReturns') and then refined by
-- 'kernelBodyReturns'.  For 'SegHist' they are the returns of the
-- histogram destination arrays, concatenated in operator order.
segOpReturns ::
  (Mem rep inner, Monad m, HasScope rep m) =>
  SegOp lvl somerep ->
  m [ExpReturns]
segOpReturns segop =
  case segop of
    SegMap _ _ _ kbody -> fromOpType kbody
    SegRed _ _ _ _ kbody -> fromOpType kbody
    SegScan _ _ _ _ kbody -> fromOpType kbody
    SegHist _ _ hist_ops _ _ ->
      concat <$> mapM (\hop -> mapM varReturns (histDest hop)) hist_ops
  where
    fromOpType kbody = do
      ets <- opType segop
      kernelBodyReturns kbody (extReturns ets)