{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.IR.SegOp
( SegOp (..),
segLevel,
segBody,
segSpace,
typeCheckSegOp,
SegSpace (..),
scopeOfSegSpace,
segSpaceDims,
HistOp (..),
histType,
splitHistResults,
SegBinOp (..),
segBinOpResults,
segBinOpChunks,
KernelBody (..),
aliasAnalyseKernelBody,
consumedInKernelBody,
ResultManifest (..),
KernelResult (..),
kernelResultCerts,
kernelResultSubExp,
SegOpMapper (..),
identitySegOpMapper,
mapSegOpM,
traverseSegOpStms,
simplifySegOp,
HasSegOp (..),
segOpRules,
segOpReturns,
)
where
import Control.Category
import Control.Monad
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Bifunctor (first)
import Data.Bitraversable
import Data.List
( elemIndex,
foldl',
groupBy,
intersperse,
isPrefixOf,
partition,
unzip4,
zip4,
)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR
import Futhark.IR.Aliases
( Aliases,
CanBeAliased (..),
)
import Futhark.IR.Mem
import Futhark.IR.Prop.Aliases
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty
( Doc,
apply,
hsep,
parens,
ppTuple',
pretty,
(<+>),
(</>),
)
import Futhark.Util.Pretty qualified as PP
import Prelude hiding (id, (.))
-- | A histogram operator: destination arrays together with the
-- combining 'Lambda', its neutral elements and the shapes involved.
data HistOp rep = HistOp
  { -- | First component of the shape that 'histType' applies to
    -- every result type.  'splitHistResults' uses its rank as the
    -- number of index results belonging to this operator.
    histShape :: Shape,
    -- | NOTE(review): presumably a hint about update contention;
    -- it is not interpreted in the part of the module shown here.
    histRaceFactor :: SubExp,
    -- | Destination arrays.  'splitHistResults' assigns one value
    -- result to each destination.
    histDest :: [VName],
    -- | NOTE(review): presumably the neutral elements of 'histOp',
    -- one per destination -- confirm against the users of this type.
    histNeutral :: [SubExp],
    -- | Second component of the shape that 'histType' applies to
    -- every result type (appended after 'histShape').
    histOpShape :: Shape,
    -- | The combining function; 'histType' is derived from its
    -- return types.
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)
-- | The types of the histograms described by a 'HistOp': every
-- return type of the operator's lambda, turned into an array of
-- shape @'histShape' op <> 'histOpShape' op@.
histType :: HistOp rep -> [Type]
histType op =
  map (`arrayOfShape` (histShape op <> histOpShape op)) $
    lambdaReturnType $
      histOp op
-- | Split a flat list of results into, per 'HistOp', a pair of
-- (index results, value results).
--
-- The first @sum ranks@ results are treated as indexes, where each
-- operator contributes as many as the rank of its 'histShape'.  The
-- remaining results are values, of which each operator takes one per
-- element of 'histDest'.
splitHistResults :: [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ops res =
  let ranks :: [Int]
      ranks = map (shapeRank . histShape) ops
      -- All index results come first, followed by all value results.
      (idxs, vals) = splitAt (sum ranks) res
   in zip
        (chunks ranks idxs)
        (chunks (map (length . histDest) ops) vals)
-- | A binary operator packaged with its neutral elements, a
-- commutativity marker and a shape.
data SegBinOp rep = SegBinOp
  { -- | Commutativity marker for 'segBinOpLambda'.
    segBinOpComm :: Commutativity,
    -- | The operator itself.
    segBinOpLambda :: Lambda rep,
    -- | Neutral elements.  The length of this list is what
    -- 'segBinOpResults' and 'segBinOpChunks' count as the number of
    -- results of the operator.
    segBinOpNeutral :: [SubExp],
    -- | NOTE(review): presumably the shape of a vectorised operator,
    -- analogous to 'histOpShape' -- not used in the code shown here.
    segBinOpShape :: Shape
  }
  deriving (Eq, Ord, Show)
-- | Total number of results across the given operators, counted as
-- the combined length of their 'segBinOpNeutral' lists.
segBinOpResults :: [SegBinOp rep] -> Int
segBinOpResults = sum . map (length . segBinOpNeutral)
-- | Split a flat list into consecutive chunks, one per operator,
-- where each chunk has as many elements as that operator has neutral
-- elements ('segBinOpNeutral').
segBinOpChunks :: [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks = chunks . map (length . segBinOpNeutral)
-- | A body consisting of statements followed by a list of
-- 'KernelResult's (rather than an ordinary 'Result').
data KernelBody rep = KernelBody
  { -- | Representation-specific body decoration.
    kernelBodyDec :: BodyDec rep,
    -- | The statements of the body.
    kernelBodyStms :: Stms rep,
    -- | The results produced by the body.
    kernelBodyResult :: [KernelResult]
  }

-- The instances are derived standalone so that the context can be
-- stated explicitly as 'RepTypes'.
deriving instance (RepTypes rep) => Ord (KernelBody rep)

deriving instance (RepTypes rep) => Show (KernelBody rep)

deriving instance (RepTypes rep) => Eq (KernelBody rep)
-- | Annotation attached to a 'Returns' result.
--
-- NOTE(review): the descriptions below are inferred from the
-- constructor names; the code that interprets them is not part of
-- this section of the module.
data ResultManifest
  = -- | Presumably: the result must not be touched by simplification.
    ResultNoSimplify
  | -- | Presumably: simplification is allowed for this result.
    ResultMaySimplify
  | -- | Presumably: the result is private to its producer -- confirm.
    ResultPrivate
  deriving (Eq, Show, Ord)
-- | A single result of a 'KernelBody'.  Every constructor carries
-- the certificates ('Certs') associated with the result; see
-- 'kernelResultCerts' and 'kernelResultSubExp' for uniform access.
data KernelResult
  = -- | Return a plain 'SubExp', tagged with a 'ResultManifest'.
    Returns ResultManifest Certs SubExp
  | -- | Write the given (slice, value) pairs into the named array.
    -- The array is reported as consumed by 'consumedInKernelBody'.
    WriteReturns
      Certs
      VName
      [(Slice SubExp, SubExp)]
  | -- | Return the named variable together with a list of pairs.
    -- NOTE(review): presumably (dimension, tile size) pairs -- confirm.
    TileReturns
      Certs
      [(SubExp, SubExp)]
      VName
  | -- | Return the named variable together with a list of triples.
    -- NOTE(review): presumably (dimension, block tile, register tile)
    -- per dimension -- confirm.
    RegTileReturns
      Certs
      [ ( SubExp,
          SubExp,
          SubExp
        )
      ]
      VName
  deriving (Eq, Show, Ord)
-- | The certificates attached to a 'KernelResult', whichever
-- constructor it was built with.
kernelResultCerts :: KernelResult -> Certs
kernelResultCerts (Returns _ cs _) = cs
kernelResultCerts (WriteReturns cs _ _) = cs
kernelResultCerts (TileReturns cs _ _) = cs
kernelResultCerts (RegTileReturns cs _ _) = cs
-- | The 'SubExp' that a 'KernelResult' stands for: the returned
-- expression for 'Returns', and otherwise the array or variable named
-- by the result, wrapped in 'Var'.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp (Returns _ _ se) = se
kernelResultSubExp (WriteReturns _ arr _) = Var arr
kernelResultSubExp (TileReturns _ _ v) = Var v
kernelResultSubExp (RegTileReturns _ _ v) = Var v
-- | The free variables of a result are those of all its components:
-- certificates, any index/tile information, and the returned value or
-- variable.
instance FreeIn KernelResult where
  freeIn' (Returns _ cs what) =
    freeIn' cs <> freeIn' what
  freeIn' (WriteReturns cs arr res) =
    freeIn' cs <> freeIn' arr <> freeIn' res
  freeIn' (TileReturns cs dims v) =
    freeIn' cs <> freeIn' dims <> freeIn' v
  freeIn' (RegTileReturns cs dims_n_tiles v) =
    freeIn' cs <> freeIn' dims_n_tiles <> freeIn' v
-- | Free variables of a kernel body: everything free in the
-- decoration, statements and results, with the names bound by the
-- body's own statements marked as bound via 'fvBind'.
instance (ASTRep rep) => FreeIn (KernelBody rep) where
  freeIn' (KernelBody dec stms res) =
    fvBind bound_in_stms $ freeIn' dec <> freeIn' stms <> freeIn' res
    where
      -- Names introduced by the statements of this body.
      bound_in_stms :: Names
      bound_in_stms = foldMap boundByStm stms
-- | Substitution is applied component-wise to the decoration, the
-- statements and the results.
instance (ASTRep rep) => Substitute (KernelBody rep) where
  substituteNames subst (KernelBody dec stms res) =
    KernelBody
      (substituteNames subst dec)
      (substituteNames subst stms)
      (substituteNames subst res)
-- | Substitution is applied to every name-carrying component of the
-- result.  The 'ResultManifest' of a 'Returns' is kept unchanged.
instance Substitute KernelResult where
  substituteNames subst (Returns manifest cs se) =
    Returns manifest (substituteNames subst cs) (substituteNames subst se)
  substituteNames subst (WriteReturns cs arr res) =
    WriteReturns
      (substituteNames subst cs)
      (substituteNames subst arr)
      (substituteNames subst res)
  substituteNames subst (TileReturns cs dims v) =
    TileReturns
      (substituteNames subst cs)
      (substituteNames subst dims)
      (substituteNames subst v)
  substituteNames subst (RegTileReturns cs dims_n_tiles v) =
    RegTileReturns
      (substituteNames subst cs)
      (substituteNames subst dims_n_tiles)
      (substituteNames subst v)
-- | Rename a kernel body: the decoration is renamed first, then the
-- statements via 'renamingStms', and the results are renamed inside
-- the continuation passed to 'renamingStms'.
instance (ASTRep rep) => Rename (KernelBody rep) where
  rename (KernelBody dec stms res) = do
    dec' <- rename dec
    renamingStms stms $ \stms' ->
      KernelBody dec' stms' <$> rename res
-- | A 'KernelResult' binds no names itself, so renaming is delegated
-- to 'substituteRename', which goes through the 'Substitute' instance.
instance Rename KernelResult where
  rename = substituteRename
-- | Perform alias analysis on a 'KernelBody'.
--
-- The decoration and statements are wrapped in an ordinary 'Body'
-- with an empty result so that 'Alias.analyseBody' can be reused; the
-- kernel results are carried over unchanged.
aliasAnalyseKernelBody ::
  (Alias.AliasableRep rep) =>
  AliasTable ->
  KernelBody rep ->
  KernelBody (Aliases rep)
aliasAnalyseKernelBody aliases (KernelBody dec stms res) =
  let -- The (empty) result of the temporary body is of no interest.
      Body dec' stms' _ = Alias.analyseBody aliases $ Body dec stms []
   in KernelBody dec' stms' res
-- | The names consumed by a kernel body.
--
-- This is whatever 'consumedInBody' reports for the statements (viewed
-- as a 'Body' with an empty result), together with the destination
-- array of every 'WriteReturns' result.
consumedInKernelBody ::
  (Aliased rep) =>
  KernelBody rep ->
  Names
consumedInKernelBody (KernelBody dec stms res) =
  consumedInBody (Body dec stms []) <> foldMap writtenBy res
  where
    -- Only 'WriteReturns' contributes a consumed name.
    writtenBy (WriteReturns _ dest _) = oneName dest
    writtenBy _ = mempty
-- | Type-check a kernel body against the expected per-result types @ts@.
--
-- The steps, in order:
--
-- 1. The body decoration is checked.
--
-- 2. For every 'WriteReturns' result, the aliases of the destination
--    array are consumed.  This happens /before/ the statements are
--    checked -- presumably so that a statement using an array that is
--    written by a result is reported as an error.
--
-- 3. The statements are checked, and inside their scope the number of
--    results is compared to the number of types, after which each
--    result is checked against its type.
checkKernelBody ::
  (TC.Checkable rep) =>
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkKernelBody ts (KernelBody (_, dec) stms kres) = do
  TC.checkBodyDec dec
  mapM_ consumeKernelResult kres
  TC.checkStms stms $ do
    unless (length ts == length kres) $
      typeError $
        "Kernel return type is "
          <> prettyTuple ts
          <> ", but body returns "
          <> prettyText (length kres)
          <> " values."
    zipWithM_ checkKernelResult kres ts
  where
    -- Abort type checking with a textual type error.
    typeError msg = TC.bad (TC.TypeError msg)

    -- Only 'WriteReturns' consumes anything.
    consumeKernelResult (WriteReturns _ arr _) =
      TC.consume =<< TC.lookupAliases arr
    consumeKernelResult _ =
      pure ()

    -- A plain return must simply have the declared type.
    checkKernelResult (Returns _ cs what) t = do
      TC.checkCerts cs
      TC.require [t] what
    -- The destination array must have exactly the declared type, every
    -- slice must be valid for it, and every written value must have the
    -- declared type reshaped to the shape of its slice.
    checkKernelResult (WriteReturns cs arr res) t = do
      TC.checkCerts cs
      arr_t <- lookupType arr
      unless (arr_t == t) $
        typeError $
          "WriteReturns result type annotation for "
            <> prettyText arr
            <> " is "
            <> prettyText t
            <> ", but inferred as "
            <> prettyText arr_t
      forM_ res $ \(slice, e) -> do
        TC.checkSlice arr_t slice
        TC.require [t `setArrayShape` sliceShape slice] e
    -- Dimensions and tile sizes must be i64, and the tile variable must
    -- be an array of the declared type whose shape is the tile sizes.
    checkKernelResult (TileReturns cs dims v) t = do
      TC.checkCerts cs
      forM_ dims $ \(dim, tile) -> do
        TC.require [Prim int64] dim
        TC.require [Prim int64] tile
      vt <- lookupType v
      unless (vt == t `arrayOfShape` Shape (map snd dims)) $
        typeError $
          "Invalid type for TileReturns " <> prettyText v
    -- All dimensions and tile sizes must be i64, and the array must have
    -- the declared type with shape @blk_tiles ++ reg_tiles@.
    checkKernelResult (RegTileReturns cs dims_n_tiles arr) t = do
      TC.checkCerts cs
      mapM_ (TC.require [Prim int64]) dims
      mapM_ (TC.require [Prim int64]) blk_tiles
      mapM_ (TC.require [Prim int64]) reg_tiles
      arr_t <- lookupType arr
      unless (arr_t == expected) $
        typeError $
          "Invalid type for RegTileReturns. Expected:\n "
            <> prettyText expected
            <> ",\ngot:\n "
            <> prettyText arr_t
      where
        (dims, blk_tiles, reg_tiles) = unzip3 dims_n_tiles
        expected = t `arrayOfShape` Shape (blk_tiles <> reg_tiles)
-- | Run 'stmMetrics' on every statement of the kernel body, in order.
kernelBodyMetrics :: (OpMetrics (Op rep)) => KernelBody rep -> MetricsM ()
kernelBodyMetrics kbody = forM_ (kernelBodyStms kbody) stmMetrics
-- | A kernel body is printed as its statements stacked on top of each
-- other, followed by @return@ and the brace-enclosed list of results.
-- The body decoration is not printed.
instance (PrettyRep rep) => Pretty (KernelBody rep) where
  pretty (KernelBody _ stms res) =
    stmsDoc </> "return" <+> resDoc
    where
      stmsDoc = PP.stack $ map pretty $ stmsToList stms
      resDoc = PP.braces $ PP.commastack $ map pretty res
-- | The documents to put in front of a result for its certificates:
-- nothing when the certificates are empty, otherwise a single document
-- with the pretty-printed certificates.
certAnnots :: Certs -> [Doc ann]
certAnnots cs = [pretty cs | cs /= mempty]
-- | Every kernel result is printed as its certificate annotations (if
-- any) followed by a form specific to the constructor.
instance Pretty KernelResult where
  -- The manifest only affects the keyword printed before the value.
  pretty (Returns manifest cs what) =
    hsep $ certAnnots cs <> [keyword <+> pretty what]
    where
      keyword = case manifest of
        ResultNoSimplify -> "returns (manifest)"
        ResultPrivate -> "returns (private)"
        ResultMaySimplify -> "returns"
  -- The destination array, then @with@ and the @slice = value@ writes.
  pretty (WriteReturns cs arr res) =
    hsep $
      certAnnots cs
        <> [pretty arr </> "with" <+> PP.apply writes]
    where
      writes = [pretty slice <+> "=" <+> pretty e | (slice, e) <- res]
  -- @tile@ applied to the @dim / tile@ pairs, then the variable.
  pretty (TileReturns cs dims v) =
    hsep $ certAnnots cs <> ["tile" <> apply tiling <+> pretty v]
    where
      tiling = [pretty dim <+> "/" <+> pretty tile | (dim, tile) <- dims]
  -- @blkreg_tile@ applied to the @dim / (blk_tile * reg_tile)@ triples,
  -- then the variable.
  pretty (RegTileReturns cs dims_n_tiles v) =
    hsep $ certAnnots cs <> ["blkreg_tile" <> apply tiling <+> pretty v]
    where
      tiling =
        [ pretty dim <+> "/" <+> parens (pretty blk_tile <+> "*" <+> pretty reg_tile)
          | (dim, blk_tile, reg_tile) <- dims_n_tiles
        ]
-- | The index space of a 'SegOp': one flat index name, plus an index
-- variable for every dimension.
data SegSpace = SegSpace
  { -- | Name of the flat index.  'scopeOfSegSpace' brings it into scope
    -- as a 64-bit integer index, just like the per-dimension indexes.
    segFlat :: VName,
    -- | The per-dimension index variables, each paired with a 'SubExp'
    -- (presumably the size of that dimension; 'checkSegSpace' requires
    -- it to be of type @i64@).
    unSegSpace :: [(VName, SubExp)]
  }
  deriving (Eq, Ord, Show)
-- | The second component of every entry of the space, in order; the
-- index variable names are dropped.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims = map snd . unSegSpace
-- | The scope introduced by a 'SegSpace': the flat index and every
-- per-dimension index variable, all bound as 64-bit integer index
-- names.
scopeOfSegSpace :: SegSpace -> Scope rep
scopeOfSegSpace (SegSpace flat dims) =
  M.fromList [(v, IndexName Int64) | v <- flat : map fst dims]
-- | Check that the 'SubExp' of every dimension in the space has type
-- @i64@.  The flat index and the index variable names are not
-- inspected.
checkSegSpace :: (TC.Checkable rep) => SegSpace -> TC.TypeM rep ()
checkSegSpace (SegSpace _ dims) =
  forM_ dims $ \(_, dim) ->
    TC.require [Prim int64] dim
-- | A segmented operation, parameterised over a level type @lvl@ and a
-- representation @rep@.
--
-- Every constructor carries a level, a 'SegSpace', a list of types and
-- a 'KernelBody'.  All constructors except 'SegMap' additionally carry
-- a list of operators between the space and the types.
data SegOp lvl rep
  = -- | No operators; just the space, the types and the body.
    SegMap lvl SegSpace [Type] (KernelBody rep)
  | -- | Carries a list of 'SegBinOp's.  NOTE(review): 'segOpType'
    -- applies @init@ to the dimensions of the space for this
    -- constructor, so the space is expected to be non-empty.
    SegRed lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | -- | Carries a list of 'SegBinOp's.
    SegScan lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | -- | Carries a list of 'HistOp's.
    SegHist lvl SegSpace [HistOp rep] [Type] (KernelBody rep)
  deriving (Eq, Ord, Show)
-- | The level of a 'SegOp', i.e. the first field of every constructor.
segLevel :: SegOp lvl rep -> lvl
segLevel segop =
  case segop of
    SegMap lvl _ _ _ -> lvl
    SegRed lvl _ _ _ _ -> lvl
    SegScan lvl _ _ _ _ -> lvl
    SegHist lvl _ _ _ _ -> lvl
-- | The 'SegSpace' of a 'SegOp', i.e. the second field of every
-- constructor.
segSpace :: SegOp lvl rep -> SegSpace
segSpace segop =
  case segop of
    SegMap _ space _ _ -> space
    SegRed _ space _ _ _ -> space
    SegScan _ space _ _ _ -> space
    SegHist _ space _ _ _ -> space
-- | The 'KernelBody' of a 'SegOp', i.e. the last field of every
-- constructor.
segBody :: SegOp lvl rep -> KernelBody rep
segBody (SegMap _ _ _ kbody) = kbody
segBody (SegRed _ _ _ _ kbody) = kbody
segBody (SegScan _ _ _ _ kbody) = kbody
segBody (SegHist _ _ _ _ kbody) = kbody
-- | The type that a single kernel result of per-result type @t@
-- contributes to the enclosing 'SegOp', depending on the kind of
-- result:
--
-- * 'WriteReturns': @t@ itself, unchanged.
--
-- * 'Returns': @t@ wrapped in one array dimension per dimension of the
--   space, with the first space dimension outermost.
--
-- * 'TileReturns': an array of @t@ whose shape is the first component
--   of each @(dim, tile)@ pair.
--
-- * 'RegTileReturns': an array of @t@ whose shape is the first
--   component of each triple.
segResultShape :: SegSpace -> Type -> KernelResult -> Type
segResultShape space t kres =
  case kres of
    WriteReturns {} ->
      t
    Returns {} ->
      foldr (\d acc -> acc `arrayOfRow` d) t (segSpaceDims space)
    TileReturns _ dims _ ->
      t `arrayOfShape` Shape [dim | (dim, _) <- dims]
    RegTileReturns _ dims_n_tiles _ ->
      t `arrayOfShape` Shape [dim | (dim, _, _) <- dims_n_tiles]
-- | Compute the types of the values produced by a 'SegOp'.
--
-- * 'SegMap': each declared type is combined with the corresponding
--   kernel body result through 'segResultShape'.
--
-- * 'SegRed' / 'SegScan': the operator results come first.  Each
--   operator return type is turned into an array whose shape is the
--   space dimensions (all but the last one for 'SegRed', all of them
--   for 'SegScan') followed by the operator's own shape.  The remaining
--   declared types are then handled as for 'SegMap'.
--
-- * 'SegHist': only operator results are produced; each is shaped by
--   all but the last space dimension, then 'histShape', then
--   'histOpShape'.
segOpType :: SegOp lvl rep -> [Type]
segOpType (SegMap _ space ts kbody) =
  zipWith (segResultShape space) ts (kernelBodyResult kbody)
segOpType (SegRed _ space reds ts kbody) =
  red_ts ++ zipWith (segResultShape space) rest_ts rest_res
  where
    -- Every space dimension except the last.
    outer_dims = init (segSpaceDims space)
    red_ts =
      [ t `arrayOfShape` shape
        | red <- reds,
          let shape = Shape outer_dims <> segBinOpShape red,
          t <- lambdaReturnType (segBinOpLambda red)
      ]
    -- The leading entries of both the declared types and the body
    -- results belong to the reduction operators; skip them.
    num_red = length red_ts
    rest_ts = drop num_red ts
    rest_res = drop num_red (kernelBodyResult kbody)
segOpType (SegScan _ space scans ts kbody) =
  scan_ts ++ zipWith (segResultShape space) rest_ts rest_res
  where
    scan_ts =
      [ t `arrayOfShape` shape
        | scan <- scans,
          let shape = Shape (segSpaceDims space) <> segBinOpShape scan,
          t <- lambdaReturnType (segBinOpLambda scan)
      ]
    -- As for 'SegRed': the scan results occupy the front of the lists.
    num_scan = length scan_ts
    rest_ts = drop num_scan ts
    rest_res = drop num_scan (kernelBodyResult kbody)
segOpType (SegHist _ space ops _ _) =
  concatMap histOpTypes ops
  where
    -- Every space dimension except the last.
    outer_dims = init (segSpaceDims space)
    histOpTypes op =
      let shape = Shape outer_dims <> histShape op <> histOpShape op
       in map (`arrayOfShape` shape) (lambdaReturnType (histOp op))
-- | The type of a 'SegOp' is whatever 'segOpType' computes, converted
-- with 'staticShapes'.  No monadic effects are performed.
instance TypedOp (SegOp lvl) where
  opType op = pure (staticShapes (segOpType op))
instance (ASTConstraints lvl) => AliasedOp (SegOp lvl) where
  -- One empty alias set per result type reported by 'segOpType'.
  opAliases op = [mempty | _ <- segOpType op]

  -- Everything consumed inside the kernel body is consumed by the
  -- operation.  A 'SegHist' additionally consumes the destination
  -- arrays ('histDest') of all its operators.
  consumedInOp op =
    case op of
      SegMap _ _ _ kbody ->
        consumedInKernelBody kbody
      SegRed _ _ _ _ kbody ->
        consumedInKernelBody kbody
      SegScan _ _ _ _ kbody ->
        consumedInKernelBody kbody
      SegHist _ _ ops _ kbody ->
        let dests = concatMap histDest ops
         in namesFromList dests <> consumedInKernelBody kbody
-- | Type-check a 'SegOp'.  The first argument is run on the level of
-- the operation before anything else is checked.
--
-- 'SegMap', 'SegRed' and 'SegScan' are delegated to 'checkScanRed'
-- ('SegMap' with an empty operator list).  'SegHist' is checked here:
-- the space and declared types are checked, then each histogram
-- operator, then the kernel body, and finally the declared types are
-- compared against the types expected from the operators.
typeCheckSegOp ::
  (TC.Checkable rep) =>
  (lvl -> TC.TypeM rep ()) ->
  SegOp lvl (Aliases rep) ->
  TC.TypeM rep ()
typeCheckSegOp checkLvl (SegMap lvl space ts kbody) =
  checkLvl lvl >> checkScanRed space [] ts kbody
typeCheckSegOp checkLvl (SegRed lvl space reds ts body) = do
  checkLvl lvl
  let reds' =
        [ (segBinOpLambda red, segBinOpNeutral red, segBinOpShape red)
          | red <- reds
        ]
  checkScanRed space reds' ts body
typeCheckSegOp checkLvl (SegScan lvl space scans ts body) = do
  checkLvl lvl
  let scans' =
        [ (segBinOpLambda scan, segBinOpNeutral scan, segBinOpShape scan)
          | scan <- scans
        ]
  checkScanRed space scans' ts body
typeCheckSegOp checkLvl (SegHist lvl space ops ts kbody) = do
  checkLvl lvl
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    nes_ts <- mapM checkHistOp ops

    checkKernelBody ts kbody

    -- The declared types must be one int64 per dimension of each
    -- operator's 'histShape', followed by the value types collected
    -- from the operators.
    let index_ts =
          concatMap
            (\op -> replicate (shapeRank (histShape op)) (Prim int64))
            ops
        bucket_ret_t = index_ts ++ concat nes_ts
    unless (bucket_ret_t == ts) $
      TC.bad $
        TC.TypeError $
          mconcat
            [ "SegHist body has return type ",
              prettyTuple ts,
              " but should have type ",
              prettyTuple bucket_ret_t
            ]
  where
    -- Every space dimension except the last.
    outer_dims = init (segSpaceDims space)

    -- Check a single histogram operator and return the types of its
    -- neutral elements, each wrapped in the operator shape.
    checkHistOp (HistOp dest_shape rf dests nes shape op) = do
      mapM_ (TC.require [Prim int64]) dest_shape
      TC.require [Prim int64] rf
      nes' <- mapM TC.checkArg nes
      mapM_ (TC.require [Prim int64]) (shapeDims shape)

      -- The operator is checked against two copies of the neutral
      -- elements, with the operator-shape dimensions stripped from
      -- their types and without aliases.
      let stripVecDims = stripArray (shapeRank shape)
          lam_args =
            [ TC.noArgAliases (first stripVecDims arg)
              | arg <- nes' ++ nes'
            ]
      TC.checkLambda op lam_args

      let nes_t = map TC.argType nes'
          op_ret_ts = lambdaReturnType op
      unless (nes_t == op_ret_ts) $
        TC.bad $
          TC.TypeError $
            mconcat
              [ "SegHist operator has return type ",
                prettyTuple op_ret_ts,
                " but neutral element has type ",
                prettyTuple nes_t
              ]

      -- Each destination must be an array of the neutral element type
      -- with the full destination shape; it is then marked as consumed.
      let dest_shape' = Shape outer_dims <> dest_shape <> shape
      forM_ (zip nes_t dests) $ \(t, dest) -> do
        TC.requireI [t `arrayOfShape` dest_shape'] dest
        dest_aliases <- TC.lookupAliases dest
        TC.consume dest_aliases

      pure [t `arrayOfShape` shape | t <- nes_t]
-- | Shared type checking for 'SegMap', 'SegRed' and 'SegScan'.
--
-- Each operator is given as its lambda, its neutral elements and its
-- shape.  After checking the space and the declared types, every
-- operator is checked inside the scope of the space; the leading
-- declared types must then equal the neutral element types (wrapped in
-- the respective operator shape), and finally the kernel body is
-- checked against all declared types.
checkScanRed ::
  (TC.Checkable rep) =>
  SegSpace ->
  [(Lambda (Aliases rep), [SubExp], Shape)] ->
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkScanRed space ops ts kbody = do
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    ne_ts <- mapM checkOperator ops

    -- Only as many declared types as the operators account for are
    -- compared here; the rest is left to 'checkKernelBody'.
    let expecting = concat ne_ts
        got = take (length expecting) ts
    unless (expecting == got) $
      TC.bad $
        TC.TypeError $
          mconcat
            [ "Wrong return for body (does not match neutral elements; expected ",
              prettyText expecting,
              "; found ",
              prettyText got,
              ")"
            ]

    checkKernelBody ts kbody
  where
    -- Check one operator and return the types of its neutral elements,
    -- each wrapped in the operator shape.
    checkOperator (lam, nes, shape) = do
      mapM_ (TC.require [Prim int64]) (shapeDims shape)
      nes' <- mapM TC.checkArg nes

      -- The lambda is applied to two copies of the neutral elements,
      -- with aliases removed.
      TC.checkLambda lam [TC.noArgAliases arg | arg <- nes' ++ nes']

      let nes_t = map TC.argType nes'
      unless (lambdaReturnType lam == nes_t) $
        TC.bad (TC.TypeError "wrong type for operator or neutral elements.")

      pure [t `arrayOfShape` shape | t <- nes_t]
-- | A record of monadic functions, one per kind of component that can
-- be transformed inside a 'SegOp'.  The lambda and body functions may
-- change the representation from @frep@ to @trep@; the others keep the
-- type of their argument.
data SegOpMapper lvl frep trep m = SegOpMapper
  { -- | Applied to 'SubExp's.
    mapOnSegOpSubExp :: SubExp -> m SubExp,
    -- | Applied to 'Lambda's.
    mapOnSegOpLambda :: Lambda frep -> m (Lambda trep),
    -- | Applied to 'KernelBody's.
    mapOnSegOpBody :: KernelBody frep -> m (KernelBody trep),
    -- | Applied to 'VName's.
    mapOnSegOpVName :: VName -> m VName,
    -- | Applied to the level.
    mapOnSegOpLevel :: lvl -> m lvl
  }
-- | A 'SegOpMapper' in which every function is 'pure', i.e. every
-- component is returned unchanged and without effects.
identitySegOpMapper :: (Monad m) => SegOpMapper lvl rep rep m
identitySegOpMapper =
  SegOpMapper
    pure -- mapOnSegOpSubExp
    pure -- mapOnSegOpLambda
    pure -- mapOnSegOpBody
    pure -- mapOnSegOpVName
    pure -- mapOnSegOpLevel
-- | Apply a 'SegOpMapper' to a 'SegSpace': 'mapOnSegOpVName' is run on
-- the first field and on the name of every dimension, and
-- 'mapOnSegOpSubExp' on the size of every dimension.  Effects happen in
-- that order, dimension by dimension.
mapOnSegSpace ::
  (Monad f) => SegOpMapper lvl frep trep f -> SegSpace -> f SegSpace
mapOnSegSpace tv (SegSpace phys dims) = do
  phys' <- mapOnSegOpVName tv phys
  dims' <- traverse onDim dims
  pure (SegSpace phys' dims')
  where
    -- Transform the name, then the size, of one dimension.
    onDim = bitraverse (mapOnSegOpVName tv) (mapOnSegOpSubExp tv)
-- | Apply a 'SegOpMapper' to a 'SegBinOp'.  The commutativity is kept
-- as is; the lambda goes through 'mapOnSegOpLambda', and both the
-- neutral elements and the dimensions of the shape go through
-- 'mapOnSegOpSubExp' (in that order).
mapSegBinOp ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  SegBinOp frep ->
  m (SegBinOp trep)
mapSegBinOp tv (SegBinOp comm red_op nes shape) = do
  red_op' <- mapOnSegOpLambda tv red_op
  nes' <- mapM onSubExp nes
  dims' <- mapM onSubExp (shapeDims shape)
  pure (SegBinOp comm red_op' nes' (Shape dims'))
  where
    onSubExp = mapOnSegOpSubExp tv
-- | Apply a 'SegOpMapper' to the immediate components of a 'SegOp',
-- rebuilding the same constructor from the mapped parts.
--
-- For every constructor the level is mapped with 'mapOnSegOpLevel',
-- the space with 'mapOnSegSpace', and the kernel body with
-- 'mapOnSegOpBody'.  The operators of 'SegRed' and 'SegScan' are mapped
-- with 'mapSegBinOp'; the operators of 'SegHist' are mapped field by
-- field by the local helper @onHistOp@.
--
-- NOTE(review): the result types of 'SegMap' go through
-- 'mapOnSegOpType' (which also hands accumulator names to
-- 'mapOnSegOpVName'), whereas 'SegRed', 'SegScan' and 'SegHist' use
-- 'mapOnType' with only the 'SubExp' mapper.  Presumably accumulator
-- names in those result types are therefore left untouched - confirm
-- whether this asymmetry is intentional before unifying.
mapSegOpM ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  SegOp lvl frep ->
  m (SegOp lvl trep)
mapSegOpM tv (SegMap lvl space ts body) =
  SegMap
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegRed lvl space reds ts lam) =
  SegRed
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapSegBinOp tv) reds
    <*> mapM (mapOnType $ mapOnSegOpSubExp tv) ts
    <*> mapOnSegOpBody tv lam
mapSegOpM tv (SegScan lvl space scans ts body) =
  SegScan
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapSegBinOp tv) scans
    <*> mapM (mapOnType $ mapOnSegOpSubExp tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegHist lvl space ops ts body) =
  SegHist
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM onHistOp ops
    <*> mapM (mapOnType $ mapOnSegOpSubExp tv) ts
    <*> mapOnSegOpBody tv body
  where
    -- Map every field of a histogram operator: the sub-expressions of
    -- both shapes, the race factor and the neutral elements go through
    -- the SubExp mapper, the destination arrays through the VName
    -- mapper, and the operator through the lambda mapper.
    onHistOp (HistOp w rf arrs nes shape op) =
      HistOp
        <$> mapM (mapOnSegOpSubExp tv) w
        <*> mapOnSegOpSubExp tv rf
        <*> mapM (mapOnSegOpVName tv) arrs
        <*> mapM (mapOnSegOpSubExp tv) nes
        <*> (Shape <$> mapM (mapOnSegOpSubExp tv) (shapeDims shape))
        <*> mapOnSegOpLambda tv op
-- | Map the names and sub-expressions occurring in a 'Type' using the
-- corresponding functions of a 'SegOpMapper'.
--
-- * 'Prim' and 'Mem' types are returned unchanged.
--
-- * For 'Array' types, every dimension of the shape is passed to
--   'mapOnSegOpSubExp'.
--
-- * For 'Acc' types, the accumulator name is passed to
--   'mapOnSegOpVName', and the index space as well as the shapes of
--   the element types are passed to 'mapOnSegOpSubExp'.
mapOnSegOpType ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  Type ->
  m Type
mapOnSegOpType _tv t@Prim {} = pure t
mapOnSegOpType tv (Acc acc ispace ts u) =
  Acc
    <$> mapOnSegOpVName tv acc
    <*> traverse (mapOnSegOpSubExp tv) ispace
    -- Only the shape component of each element type is mapped; the
    -- uniqueness component is carried over as-is.
    <*> traverse (bitraverse (traverse (mapOnSegOpSubExp tv)) pure) ts
    <*> pure u
mapOnSegOpType tv (Array et shape u) =
  Array et <$> traverse (mapOnSegOpSubExp tv) shape <*> pure u
mapOnSegOpType _tv (Mem s) = pure $ Mem s
-- | Rephrase a 'SegBinOp' into another representation.  Only the
-- operator lambda is rephrased (with 'rephraseLambda'); the
-- commutativity, neutral elements and shape are copied unchanged.
rephraseBinOp ::
  (Monad f) =>
  Rephraser f from rep ->
  SegBinOp from ->
  f (SegBinOp rep)
rephraseBinOp r (SegBinOp comm lam nes shape) =
  SegBinOp comm <$> rephraseLambda r lam <*> pure nes <*> pure shape
-- | Rephrase a 'KernelBody' into another representation.  The body
-- decoration is converted with 'rephraseBodyDec' and every statement
-- with 'rephraseStm'; the kernel results are copied unchanged.
rephraseKernelBody ::
  (Monad f) =>
  Rephraser f from rep ->
  KernelBody from ->
  f (KernelBody rep)
rephraseKernelBody r (KernelBody dec stms res) =
  KernelBody
    <$> rephraseBodyDec r dec
    <*> traverse (rephraseStm r) stms
    <*> pure res
-- | Rephrasing a 'SegOp' keeps the level, space and result types as
-- they are, and rephrases only the parts that depend on the
-- representation: the kernel body (via 'rephraseKernelBody') and the
-- operator lambdas of reductions, scans and histograms.
instance RephraseOp (SegOp lvl) where
  rephraseInOp r (SegMap lvl space ts body) =
    SegMap lvl space ts <$> rephraseKernelBody r body
  rephraseInOp r (SegRed lvl space reds ts body) =
    SegRed lvl space
      <$> mapM (rephraseBinOp r) reds
      <*> pure ts
      <*> rephraseKernelBody r body
  rephraseInOp r (SegScan lvl space scans ts body) =
    SegScan lvl space
      <$> mapM (rephraseBinOp r) scans
      <*> pure ts
      <*> rephraseKernelBody r body
  rephraseInOp r (SegHist lvl space hists ts body) =
    SegHist lvl space
      <$> mapM onOp hists
      <*> pure ts
      <*> rephraseKernelBody r body
    where
      -- Only the operator lambda of a histogram operation depends on
      -- the representation; all other fields are copied unchanged.
      onOp (HistOp w rf arrs nes shape op) =
        HistOp w rf arrs nes shape <$> rephraseLambda r op
-- | Apply a statement-traversal function to the statements inside a
-- 'SegOp': those of the kernel body and, through 'traverseLambdaStms',
-- those of the operator lambdas.  Everything else is left to
-- 'identitySegOpMapper'.
--
-- The traversal function is always given a scope that includes the
-- scope of the 'SegOp's own space ('scopeOfSegSpace'): directly for the
-- kernel body, and prepended to the scope supplied by
-- 'traverseLambdaStms' for lambdas.
traverseSegOpStms :: (Monad m) => OpStmsTraverser m (SegOp lvl rep) rep
traverseSegOpStms f segop = mapSegOpM mapper segop
  where
    seg_scope = scopeOfSegSpace (segSpace segop)
    -- Extend whatever scope the lambda traversal provides with the
    -- scope of the segment space.
    f' scope = f (seg_scope <> scope)
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = traverseLambdaStms f',
          mapOnSegOpBody = onBody
        }
    -- The decoration and results of the kernel body are kept as-is;
    -- only its statements are traversed.
    onBody (KernelBody dec stms res) =
      KernelBody dec <$> f seg_scope stms <*> pure res
-- | Name substitution on a 'SegOp' is performed component-wise: every
-- sub-expression, lambda, kernel body, variable name and level reached
-- by 'mapSegOpM' has the substitution applied to it.
instance
  (ASTRep rep, Substitute lvl) =>
  Substitute (SegOp lvl rep)
  where
  substituteNames substs segop =
    runIdentity $ mapSegOpM substituter segop
    where
      -- Apply the substitution to one component.  The explicit
      -- signature keeps this helper usable at all five component types.
      onPart :: (Substitute a) => a -> Identity a
      onPart part = pure (substituteNames substs part)

      substituter =
        SegOpMapper
          { mapOnSegOpSubExp = onPart,
            mapOnSegOpLambda = onPart,
            mapOnSegOpBody = onPart,
            mapOnSegOpVName = onPart,
            mapOnSegOpLevel = onPart
          }
-- | Renaming a 'SegOp': the names in the scope of its index space are
-- passed to 'renameBound', and within that every component reached by
-- 'mapSegOpM' is renamed with its own 'Rename' instance.
instance (ASTRep rep, ASTConstraints lvl) => Rename (SegOp lvl rep) where
  rename segop = renameBound boundBySpace $ mapSegOpM renamer segop
    where
      -- Names introduced by the SegOp's own index space.
      boundBySpace = M.keys . scopeOfSegSpace $ segSpace segop

      renamer =
        SegOpMapper
          { mapOnSegOpSubExp = rename,
            mapOnSegOpLambda = rename,
            mapOnSegOpBody = rename,
            mapOnSegOpVName = rename,
            mapOnSegOpLevel = rename
          }
-- | Free variables of a 'SegOp': the free-variable sets of all
-- components reached by 'mapSegOpM' are accumulated, and the names in
-- the scope of the index space are then handed to 'fvBind' together
-- with that accumulated set.
instance (ASTRep rep, FreeIn lvl) => FreeIn (SegOp lvl rep) where
  freeIn' segop =
    fvBind spaceNames $ execState (mapSegOpM collector segop) mempty
    where
      -- Names in the scope of the SegOp's index space.
      spaceNames = namesFromList . M.keys . scopeOfSegSpace $ segSpace segop

      -- Append the measure of a component to the state and return the
      -- component unchanged, so the traversal rebuilds an identical
      -- SegOp (which is then discarded by 'execState').
      tally measure x = do
        modify (<> measure x)
        pure x

      collector =
        SegOpMapper
          (tally freeIn')
          (tally freeIn')
          (tally freeIn')
          (tally freeIn')
          (tally freeIn')
-- | Metrics for a 'SegOp'.  Everything is recorded 'inside' a label
-- naming the constructor.  For 'SegRed' and 'SegScan' each operator
-- lambda is recorded inside a nested @SegBinOp@ label; for 'SegHist'
-- the operator lambdas are recorded directly.  The kernel body is
-- always recorded last.
instance (OpMetrics (Op rep)) => OpMetrics (SegOp lvl rep) where
  opMetrics segop = case segop of
    SegMap _ _ _ kbody ->
      inside "SegMap" (kernelBodyMetrics kbody)
    SegRed _ _ binops _ kbody ->
      inside "SegRed" $
        forM_ binops (\binop -> inside "SegBinOp" $ lambdaMetrics $ segBinOpLambda binop)
          >> kernelBodyMetrics kbody
    SegScan _ _ binops _ kbody ->
      inside "SegScan" $
        forM_ binops (\binop -> inside "SegBinOp" $ lambdaMetrics $ segBinOpLambda binop)
          >> kernelBodyMetrics kbody
    SegHist _ _ histops _ kbody ->
      inside "SegHist" $
        forM_ histops (\histop -> lambdaMetrics $ histOp histop)
          >> kernelBodyMetrics kbody
-- | A 'SegSpace' is printed as the 'apply' of one @i < d@ entry per
-- dimension, followed by the first 'SegSpace' field in parentheses,
-- prefixed with @~@.
instance Pretty SegSpace where
  pretty (SegSpace flat dims) = apply dimDocs <+> parens flatDoc
    where
      dimDocs = [pretty gtid <+> "<" <+> pretty bound | (gtid, bound) <- dims]
      flatDoc = "~" <> pretty flat
-- | A 'SegBinOp' is printed as its neutral elements in braces, then its
-- shape, then the operator lambda.  The lambda is prefixed with
-- @commutative @ when the operator is marked 'Commutative'.
instance (PrettyRep rep) => Pretty (SegBinOp rep) where
  pretty (SegBinOp comm lam nes shape) =
    neutralDoc
      <> PP.comma
      </> pretty shape
      <> PP.comma
      </> commTag
      <> pretty lam
    where
      neutralDoc = PP.braces . PP.commasep $ map pretty nes

      -- Note the trailing space: the tag is glued directly to the lambda.
      commTag
        | Commutative <- comm = "commutative "
        | otherwise = mempty
-- | Pretty-printing of 'SegOp's.  Each form starts with a keyword
-- (@segmap@, @segred@, @segscan@ or @seghist@) immediately followed by
-- the level, then the index space, the return types after a colon, and
-- finally the kernel body in a braced block.  All forms except
-- @segmap@ additionally print a parenthesised, comma-separated list of
-- operators between the space and the colon.
instance (PrettyRep rep, PP.Pretty lvl) => PP.Pretty (SegOp lvl rep) where
  pretty segop = case segop of
    SegMap lvl space ts body ->
      "segmap"
        <> pretty lvl
        </> PP.align (pretty space)
        <+> PP.colon
        <+> ppTuple' (map pretty ts)
        <+> PP.nestedBlock "{" "}" (pretty body)
    SegRed lvl space reds ts body ->
      withOperators "segred" lvl space (map pretty reds) ts body
    SegScan lvl space scans ts body ->
      withOperators "segscan" lvl space (map pretty scans) ts body
    SegHist lvl space ops ts body ->
      withOperators "seghist" lvl space (map ppHistOp ops) ts body
    where
      -- Layout shared by the three operator-carrying forms; only the
      -- keyword and the already-rendered operators differ.
      withOperators keyword lvl space opDocs ts body =
        keyword
          <> pretty lvl
          </> PP.align (pretty space)
          </> PP.parens (mconcat $ intersperse (PP.comma <> PP.line) opDocs)
          </> PP.colon
          <+> ppTuple' (map pretty ts)
          <+> PP.nestedBlock "{" "}" (pretty body)

      -- Render one histogram operator as its six fields in constructor
      -- order, comma-separated; the two list fields are put in braces.
      ppHistOp (HistOp hshape rfactor dests neutrals vshape lam) =
        pretty hshape
          <> PP.comma
          <+> pretty rfactor
          <> PP.comma
          </> PP.braces (PP.commasep $ map pretty dests)
          <> PP.comma
          </> PP.braces (PP.commasep $ map pretty neutrals)
          <> PP.comma
          </> pretty vshape
          <> PP.comma
          </> pretty lam
-- | Adding alias information to a 'SegOp': lambdas are processed with
-- 'Alias.analyseLambda' and kernel bodies with
-- 'aliasAnalyseKernelBody', both using the supplied alias table.  All
-- other components are returned unchanged.
instance CanBeAliased (SegOp lvl) where
  addOpAliases aliasTable segop =
    runIdentity $ mapSegOpM aliaser segop
    where
      aliaser =
        SegOpMapper
          { mapOnSegOpSubExp = pure,
            mapOnSegOpLambda = \lam -> pure (Alias.analyseLambda aliasTable lam),
            mapOnSegOpBody = \kbody -> pure (aliasAnalyseKernelBody aliasTable kbody),
            mapOnSegOpVName = pure,
            mapOnSegOpLevel = pure
          }
-- | Convert a 'KernelBody' to the 'Wise' representation: the statements
-- are converted with 'informStms' and the body is rebuilt with
-- 'mkWiseKernelBody' from the original decoration and results.
informKernelBody :: (Informing rep) => KernelBody rep -> KernelBody (Wise rep)
informKernelBody (KernelBody dec stms res) =
  mkWiseKernelBody dec wiseStms res
  where
    wiseStms = informStms stms
-- | Adding simplifier "wisdom" to a 'SegOp': lambdas go through
-- 'informLambda' and kernel bodies through 'informKernelBody'.  All
-- other components are returned unchanged.
instance CanBeWise (SegOp lvl) where
  addOpWisdom segop = runIdentity (mapSegOpM informer segop)
    where
      informer =
        SegOpMapper
          { mapOnSegOpSubExp = pure,
            mapOnSegOpLambda = \lam -> pure (informLambda lam),
            mapOnSegOpBody = \kbody -> pure (informKernelBody kbody),
            mapOnSegOpVName = pure,
            mapOnSegOpLevel = pure
          }
-- | Symbolic indexing into the results of a 'SegMap'.  Only results
-- of the form @Returns ResultMaySimplify@ whose value is a variable
-- can be resolved; any other 'SegOp' constructor, and any other kind
-- of result, produces 'Nothing'.
instance (ASTRep rep) => ST.IndexOp (SegOp lvl rep) where
  indexOp vtable k (SegMap _ space _ kbody) is = do
    Returns ResultMaySimplify _ se <- maybeNth k (kernelBodyResult kbody)
    -- There must be at least one index per segment dimension.
    guard (length gtids <= length is)
    case se of
      Var v -> M.lookup v stms_table
      _ -> Nothing
    where
      gtids = map fst (unSegSpace space)

      -- The indexes that remain once every thread index of the space
      -- has been paired with one.
      excess_is = drop (length gtids) is

      -- Initially, each thread index simply stands for the index
      -- expression at the same position.
      gtids_table =
        M.fromList
          [ (gtid, ST.Indexed mempty (untyped i))
            | (gtid, i) <- zip gtids is
          ]

      -- The table after walking through the body statements in order.
      stms_table =
        foldl' expandIndexedTable gtids_table (kernelBodyStms kbody)

      -- Extend the table with the name bound by a statement, provided
      -- the statement binds exactly one name and is either expressible
      -- as a 'PrimExp' or is an 'Index' into an array available in
      -- @vtable@ whose slice has as many slice dimensions as there are
      -- excess indexes.  All other statements leave the table as is.
      expandIndexedTable table stm =
        case patNames (stmPat stm) of
          [v]
            | Just (pe, cs) <-
                runWriterT (primExpFromExp (asPrimExp table) (stmExp stm)) ->
                M.insert v (ST.Indexed (stmCerts stm <> cs) pe) table
            | BasicOp (Index arr slice) <- stmExp stm,
              length (sliceDims slice) == length excess_is,
              arr `ST.available` vtable,
              Just (slice', cs) <- asPrimExpSlice table slice ->
                M.insert
                  v
                  ( ST.IndexedArray
                      (stmCerts stm <> cs)
                      arr
                      (fixSlice (fmap isInt64 slice') excess_is)
                  )
                  table
          _ -> table

      -- Convert every component of a slice to a 'PrimExp', collecting
      -- the certificates picked up along the way.
      asPrimExpSlice table slice =
        runWriterT (traverse (primExpFromSubExpM (asPrimExp table)) slice)

      -- Resolve a variable: first through the table (emitting its
      -- certificates), then as a leaf if @vtable@ says it has
      -- primitive type; otherwise fail.
      asPrimExp table v =
        case M.lookup v table of
          Just (ST.Indexed cs e) -> e <$ tell cs
          _ ->
            case ST.lookupType v vtable of
              Just (Prim pt) -> pure (LeafExp v pt)
              _ -> lift Nothing
  indexOp _ _ _ _ = Nothing
-- | A 'SegOp' is never reported as cheap and always reported as safe.
-- Every one of its results is listed as depending on all names free
-- in the operation.
instance (ASTConstraints lvl) => IsOp (SegOp lvl) where
  cheapOp = const False
  safeOp = const True

  -- One entry per result type, each the full set of free names.
  opDependencies op = freeIn op <$ segOpType op
-- | Simplifying a 'SegSpace' simplifies each dimension size; the flat
-- identifier and the per-dimension thread indexes are kept as they
-- are.
instance Engine.Simplifiable SegSpace where
  simplify (SegSpace phys dims) = do
    dims' <- forM dims $ \(gtid, dim) ->
      (,) gtid <$> Engine.simplify dim
    pure (SegSpace phys dims')
-- | Simplifying a 'KernelResult' simplifies its certificates and its
-- operands, in field order.  The 'ResultManifest' of a 'Returns' is
-- passed through untouched.
instance Engine.Simplifiable KernelResult where
  simplify (Returns manifest cs what) = do
    cs' <- Engine.simplify cs
    what' <- Engine.simplify what
    pure (Returns manifest cs' what')
  simplify (WriteReturns cs a res) = do
    cs' <- Engine.simplify cs
    a' <- Engine.simplify a
    res' <- Engine.simplify res
    pure (WriteReturns cs' a' res')
  simplify (TileReturns cs dims what) = do
    cs' <- Engine.simplify cs
    dims' <- Engine.simplify dims
    what' <- Engine.simplify what
    pure (TileReturns cs' dims' what')
  simplify (RegTileReturns cs dims_n_tiles what) = do
    cs' <- Engine.simplify cs
    dims_n_tiles' <- Engine.simplify dims_n_tiles
    what' <- Engine.simplify what
    pure (RegTileReturns cs' dims_n_tiles' what')
-- | Construct a 'KernelBody' in the 'Wise' representation.  The body
-- decoration is taken from the ordinary body that 'mkWiseBody'
-- produces for the same statements, using the 'kernelResultSubExp' of
-- each kernel result as the body result.
mkWiseKernelBody ::
  (Informing rep) =>
  BodyDec rep ->
  Stms (Wise rep) ->
  [KernelResult] ->
  KernelBody (Wise rep)
mkWiseKernelBody dec stms res = KernelBody wise_dec stms res
  where
    -- Only the decoration of the intermediate body is of interest.
    Body wise_dec _ _ =
      mkWiseBody dec stms (subExpsRes (map kernelResultSubExp res))
-- | Construct a 'KernelBody' inside a builder monad.  The body
-- decoration is obtained by running 'mkBodyM' on the same statements,
-- with the 'kernelResultSubExp' of each kernel result as the body
-- result.
mkKernelBodyM ::
  (MonadBuilder m) =>
  Stms (Rep m) ->
  [KernelResult] ->
  m (KernelBody (Rep m))
mkKernelBodyM stms kres =
  mkBodyM stms (subExpsRes (map kernelResultSubExp kres))
    >>= \(Body dec _ _) -> pure (KernelBody dec stms kres)
-- | Simplify the body of a 'SegOp' with the given space.  Returns the
-- simplified body together with the statements that were hoisted out
-- of it.
simplifyKernelBody ::
  (Engine.SimplifiableRep rep, BodyDec rep ~ ()) =>
  SegSpace ->
  KernelBody (Wise rep) ->
  Engine.SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
simplifyKernelBody space (KernelBody _ stms res) = do
  par_blocker <-
    Engine.asksEngineEnv
      (\env -> Engine.blockHoistPar (Engine.envHoistBlockers env))

  -- A statement stays inside the body if any of these hold.
  let blocker =
        Engine.hasFree bound_here
          `Engine.orIf` Engine.isOp
          `Engine.orIf` par_blocker
          `Engine.orIf` Engine.isConsumed
          `Engine.orIf` Engine.isConsuming
          `Engine.orIf` Engine.isDeviceMigrated

      -- Mark the arrays written by the results as consumed in the
      -- symbol table.
      consumeWritten vtable =
        foldl' (flip ST.consume) vtable written_arrs

      -- The symbol table adjustments under which the body is
      -- simplified; the rightmost adjustment is the outermost call.
      inKernelContext =
        Engine.localVtable consumeWritten
          . Engine.localVtable (\vtable -> vtable <> scope_vtable)
          . Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True})
          . Engine.enterLoop

      -- Simplify the results with the names bound by the body
      -- statements passed to 'ST.hideCertified', and report the names
      -- free in the simplified results as used.
      simplifyResults = do
        res' <-
          Engine.localVtable
            (ST.hideCertified stm_names)
            (traverse Engine.simplify res)
        pure (res', UT.usages (freeIn res'))

  (body_res, body_stms, hoisted) <-
    inKernelContext (Engine.blockIf blocker stms simplifyResults)

  pure (mkWiseKernelBody () body_stms body_res, hoisted)
  where
    scope_vtable = segSpaceSymbolTable space
    bound_here = namesFromList (M.keys (scopeOfSegSpace space))
    stm_names = namesFromList (M.keys (scopeOf stms))

    -- Destination arrays of the 'WriteReturns' results.
    written_arrs = [arr | WriteReturns _ arr _ <- res]
-- | Run 'Engine.simplifyLambda' on the lambda with the given names,
-- wrapped in 'Engine.blockMigrated'.  Returns the simplified lambda
-- and the hoisted statements.
simplifyLambda ::
  (Engine.SimplifiableRep rep) =>
  Names ->
  Lambda (Wise rep) ->
  Engine.SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambda bound lam =
  Engine.blockMigrated (Engine.simplifyLambda bound lam)
-- | Build a symbol table for a 'SegSpace': the flat identifier is
-- entered as an 'Int64' index name, and each thread index is inserted
-- as an 'Int64' loop variable bounded by its dimension size.
segSpaceSymbolTable :: (ASTRep rep) => SegSpace -> ST.SymbolTable rep
segSpaceSymbolTable (SegSpace flat gtids_and_dims) =
  foldl' addLoopVar flat_table gtids_and_dims
  where
    -- Table that knows only about the flat identifier.
    flat_table = ST.fromScope (M.singleton flat (IndexName Int64))

    addLoopVar table (gtid, dim) = ST.insertLoopVar gtid Int64 dim table
-- | Simplify a single 'SegBinOp'.
--
-- The operator lambda is simplified with 'ST.simplifyMemory' switched on in
-- the local symbol table, and with @phys_id@ passed to 'simplifyLambda' as
-- its set of names.  The shape and the neutral elements are simplified as
-- ordinary values.  Returns the rebuilt operator together with the
-- statements that 'simplifyLambda' returned alongside the lambda.
simplifySegBinOp ::
  (Engine.SimplifiableRep rep) =>
  VName ->
  SegBinOp (Wise rep) ->
  Engine.SimpleM rep (SegBinOp (Wise rep), Stms (Wise rep))
simplifySegBinOp phys_id (SegBinOp comm lam nes shape) = do
  (lam', hoisted) <-
    Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True}) $
      simplifyLambda (oneName phys_id) lam
  shape' <- Engine.simplify shape
  nes' <- mapM Engine.simplify nes
  pure (SegBinOp comm lam' nes' shape', hoisted)
-- | Simplify the given 'SegOp'.
--
-- For every constructor the level, space and result types are simplified
-- together, the operators (if any) are simplified with the names of the
-- original space in the symbol table, and finally the kernel body is
-- simplified.  The second component of the result contains all statements
-- returned by simplifying the operators followed by those returned by
-- simplifying the body.
simplifySegOp ::
  ( Engine.SimplifiableRep rep,
    BodyDec rep ~ (),
    Engine.Simplifiable lvl
  ) =>
  SegOp lvl (Wise rep) ->
  Engine.SimpleM rep (SegOp lvl (Wise rep), Stms (Wise rep))
simplifySegOp (SegMap lvl space ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure
    ( SegMap lvl' space' ts' kbody',
      body_hoisted
    )
simplifySegOp (SegRed lvl space reds ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (reds', reds_hoisted) <- simplifySegBinOps space reds
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure
    ( SegRed lvl' space' reds' ts' kbody',
      reds_hoisted <> body_hoisted
    )
simplifySegOp (SegScan lvl space scans ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (scans', scans_hoisted) <- simplifySegBinOps space scans
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure
    ( SegScan lvl' space' scans' ts' kbody',
      scans_hoisted <> body_hoisted
    )
simplifySegOp (SegHist lvl space ops ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)

  -- Every histogram destination is passed through 'ST.consume' in the
  -- symbol table for the duration of simplifying the operators and body.
  Engine.localVtable (flip (foldr ST.consume) $ concatMap histDest ops) $ do
    (ops', ops_hoisted) <- fmap unzip . forM ops $
      \(HistOp w rf arrs nes dims lam) -> do
        w' <- Engine.simplify w
        rf' <- Engine.simplify rf
        arrs' <- Engine.simplify arrs
        nes' <- Engine.simplify nes
        dims' <- Engine.simplify dims
        (lam', op_hoisted) <-
          Engine.localVtable (<> scope_vtable) $
            Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True}) $
              simplifyLambda (oneName (segFlat space)) lam
        pure
          ( HistOp w' rf' arrs' nes' dims' lam',
            op_hoisted
          )

    (kbody', body_hoisted) <- simplifyKernelBody space kbody

    pure
      ( SegHist lvl' space' ops' ts' kbody',
        mconcat ops_hoisted <> body_hoisted
      )
  where
    scope = scopeOfSegSpace space
    scope_vtable = ST.fromScope scope

-- | Simplify the operators of a reduction or scan.  The symbol table is
-- locally extended with the scope of the given 'SegSpace', and each
-- operator is simplified by 'simplifySegBinOp' with the flat index of the
-- space.  The per-operator statements are concatenated in operator order.
simplifySegBinOps ::
  (Engine.SimplifiableRep rep) =>
  SegSpace ->
  [SegBinOp (Wise rep)] ->
  Engine.SimpleM rep ([SegBinOp (Wise rep)], Stms (Wise rep))
simplifySegBinOps space ops = do
  (ops', hoisted) <-
    Engine.localVtable (<> ST.fromScope (scopeOfSegSpace space)) $
      mapAndUnzipM (simplifySegBinOp (segFlat space)) ops
  pure (ops', mconcat hoisted)
-- | Representations whose 'Op' can be viewed as, and constructed from, a
-- 'SegOp'.
class HasSegOp rep where
  -- | The level type used by the 'SegOp's of this representation.
  type SegOpLevel rep

  -- | Try to view an 'Op' as a 'SegOp'; 'Nothing' if it is not one.
  asSegOp :: Op rep -> Maybe (SegOp (SegOpLevel rep) rep)

  -- | Turn a 'SegOp' into an 'Op' of the representation.
  segOp :: SegOp (SegOpLevel rep) rep -> Op rep
-- | Simplification rules for 'SegOp's: one top-down rule
-- ('segOpRuleTopDown') and one bottom-up rule ('segOpRuleBottomUp').
segOpRules ::
  (HasSegOp rep, BuilderOps rep, Buildable rep, Aliased rep) =>
  RuleBook rep
segOpRules =
  ruleBook [RuleOp segOpRuleTopDown] [RuleOp segOpRuleBottomUp]
-- | Top-down rule on 'Op's: if the operation is a 'SegOp' (according to
-- 'asSegOp') it is handed to 'topDownSegOp'; any other operation is
-- skipped.
segOpRuleTopDown ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  TopDownRuleOp rep
segOpRuleTopDown vtable pat dec op
  | Just op' <- asSegOp op =
      topDownSegOp vtable pat dec op'
  | otherwise =
      Skip
-- | Bottom-up rule on 'Op's: if the operation is a 'SegOp' (according to
-- 'asSegOp') it is handed to 'bottomUpSegOp'; any other operation is
-- skipped.
segOpRuleBottomUp ::
  (HasSegOp rep, BuilderOps rep, Aliased rep) =>
  BottomUpRuleOp rep
segOpRuleBottomUp vtable pat dec op
  | Just op' <- asSegOp op =
      bottomUpSegOp vtable pat dec op'
  | otherwise =
      Skip
topDownSegOp ::
(HasSegOp rep, BuilderOps rep, Buildable rep) =>
ST.SymbolTable rep ->
Pat (LetDec rep) ->
StmAux (ExpDec rep) ->
SegOp (SegOpLevel rep) rep ->
Rule rep
topDownSegOp :: forall rep.
(HasSegOp rep, BuilderOps rep, Buildable rep) =>
SymbolTable rep
-> Pat (LetDec rep)
-> StmAux (ExpDec rep)
-> SegOp (SegOpLevel rep) rep
-> Rule rep
topDownSegOp SymbolTable rep
vtable (Pat [PatElem (LetDec rep)]
kpes) StmAux (ExpDec rep)
dec (SegMap SegOpLevel rep
lvl SegSpace
space [Type]
ts (KernelBody BodyDec rep
_ Stms rep
kstms [KernelResult]
kres)) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
([Type]
ts', [PatElem (LetDec rep)]
kpes', [KernelResult]
kres') <-
[(Type, PatElem (LetDec rep), KernelResult)]
-> ([Type], [PatElem (LetDec rep)], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Type, PatElem (LetDec rep), KernelResult)]
-> ([Type], [PatElem (LetDec rep)], [KernelResult]))
-> RuleM rep [(Type, PatElem (LetDec rep), KernelResult)]
-> RuleM rep ([Type], [PatElem (LetDec rep)], [KernelResult])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Type, PatElem (LetDec rep), KernelResult) -> RuleM rep Bool)
-> [(Type, PatElem (LetDec rep), KernelResult)]
-> RuleM rep [(Type, PatElem (LetDec rep), KernelResult)]
forall (m :: * -> *) a.
Applicative m =>
(a -> m Bool) -> [a] -> m [a]
filterM (Type, PatElem (LetDec rep), KernelResult) -> RuleM rep Bool
checkForInvarianceResult ([Type]
-> [PatElem (LetDec rep)]
-> [KernelResult]
-> [(Type, PatElem (LetDec rep), KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
ts [PatElem (LetDec rep)]
kpes [KernelResult]
kres)
Bool -> RuleM rep () -> RuleM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([KernelResult]
kres [KernelResult] -> [KernelResult] -> Bool
forall a. Eq a => a -> a -> Bool
== [KernelResult]
kres') RuleM rep ()
forall rep a. RuleM rep a
cannotSimplify
KernelBody rep
kbody <- Stms (Rep (RuleM rep))
-> [KernelResult] -> RuleM rep (KernelBody (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
Stms (Rep m) -> [KernelResult] -> m (KernelBody (Rep m))
mkKernelBodyM Stms rep
Stms (Rep (RuleM rep))
kstms [KernelResult]
kres'
Stm (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (RuleM rep)) -> RuleM rep ())
-> Stm (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> StmAux (ExpDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep))
-> Stm (Rep (RuleM rep))
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
kpes') StmAux (ExpDec rep)
StmAux (ExpDec (Rep (RuleM rep)))
dec (Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall rep. Op rep -> Exp rep
Op (Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)))
-> Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel rep) rep -> Op rep
forall rep. HasSegOp rep => SegOp (SegOpLevel rep) rep -> Op rep
segOp (SegOp (SegOpLevel rep) rep -> Op rep)
-> SegOp (SegOpLevel rep) rep -> Op rep
forall a b. (a -> b) -> a -> b
$ SegOpLevel rep
-> SegSpace
-> [Type]
-> KernelBody rep
-> SegOp (SegOpLevel rep) rep
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap SegOpLevel rep
lvl SegSpace
space [Type]
ts' KernelBody rep
kbody
where
isInvariant :: SubExp -> Bool
isInvariant Constant {} = Bool
True
isInvariant (Var VName
v) = Maybe (Entry rep) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Entry rep) -> Bool) -> Maybe (Entry rep) -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> SymbolTable rep -> Maybe (Entry rep)
forall rep. VName -> SymbolTable rep -> Maybe (Entry rep)
ST.lookup VName
v SymbolTable rep
vtable
checkForInvarianceResult :: (Type, PatElem (LetDec rep), KernelResult) -> RuleM rep Bool
checkForInvarianceResult (Type
_, PatElem (LetDec rep)
pe, Returns ResultManifest
rm Certs
cs SubExp
se)
| Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty,
ResultManifest
rm ResultManifest -> ResultManifest -> Bool
forall a. Eq a => a -> a -> Bool
== ResultManifest
ResultMaySimplify,
SubExp -> Bool
isInvariant SubExp
se = do
[VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$
ShapeBase SubExp -> SubExp -> BasicOp
Replicate ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> ShapeBase SubExp) -> [SubExp] -> ShapeBase SubExp
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space) SubExp
se
Bool -> RuleM rep Bool
forall a. a -> RuleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Bool
False
checkForInvarianceResult (Type, PatElem (LetDec rep), KernelResult)
_ =
Bool -> RuleM rep Bool
forall a. a -> RuleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Bool
True
topDownSegOp SymbolTable rep
_ (Pat [PatElem (LetDec rep)]
pes) StmAux (ExpDec rep)
_ (SegRed SegOpLevel rep
lvl SegSpace
space [SegBinOp rep]
ops [Type]
ts KernelBody rep
kbody)
| [SegBinOp rep] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegBinOp rep]
ops Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
[[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings <-
((SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
-> (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
-> Bool)
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
forall a. (a -> a -> Bool) -> [a] -> [[a]]
groupBy (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
-> (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
-> Bool
forall {rep} {b} {rep} {b}.
(SegBinOp rep, b) -> (SegBinOp rep, b) -> Bool
sameShape ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> [[(SegBinOp rep,
[(PatElem (LetDec rep), Type, KernelResult)])]])
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
forall a b. (a -> b) -> a -> b
$
[SegBinOp rep]
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp rep]
ops ([[(PatElem (LetDec rep), Type, KernelResult)]]
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])])
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
forall a b. (a -> b) -> a -> b
$
[Int]
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp rep -> Int) -> [SegBinOp rep] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp rep -> [SubExp]) -> SegBinOp rep -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral) [SegBinOp rep]
ops) ([(PatElem (LetDec rep), Type, KernelResult)]
-> [[(PatElem (LetDec rep), Type, KernelResult)]])
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
forall a b. (a -> b) -> a -> b
$
[PatElem (LetDec rep)]
-> [Type]
-> [KernelResult]
-> [(PatElem (LetDec rep), Type, KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (LetDec rep)]
red_pes [Type]
red_ts [KernelResult]
red_res,
([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Bool)
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
-> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any ((Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1) (Int -> Bool)
-> ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Int)
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length) [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
let ([SegBinOp rep]
ops', [[(PatElem (LetDec rep), Type, KernelResult)]]
aux) = [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> ([SegBinOp rep], [[(PatElem (LetDec rep), Type, KernelResult)]])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> ([SegBinOp rep],
[[(PatElem (LetDec rep), Type, KernelResult)]]))
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> ([SegBinOp rep], [[(PatElem (LetDec rep), Type, KernelResult)]])
forall a b. (a -> b) -> a -> b
$ ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Maybe
(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)]))
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Maybe
(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
forall {rep} {a}.
Buildable rep =>
[(SegBinOp rep, [a])] -> Maybe (SegBinOp rep, [a])
combineOps [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings
([PatElem (LetDec rep)]
red_pes', [Type]
red_ts', [KernelResult]
red_res') = [(PatElem (LetDec rep), Type, KernelResult)]
-> ([PatElem (LetDec rep)], [Type], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(PatElem (LetDec rep), Type, KernelResult)]
-> ([PatElem (LetDec rep)], [Type], [KernelResult]))
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> ([PatElem (LetDec rep)], [Type], [KernelResult])
forall a b. (a -> b) -> a -> b
$ [[(PatElem (LetDec rep), Type, KernelResult)]]
-> [(PatElem (LetDec rep), Type, KernelResult)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[(PatElem (LetDec rep), Type, KernelResult)]]
aux
pes' :: [PatElem (LetDec rep)]
pes' = [PatElem (LetDec rep)]
red_pes' [PatElem (LetDec rep)]
-> [PatElem (LetDec rep)] -> [PatElem (LetDec rep)]
forall a. [a] -> [a] -> [a]
++ [PatElem (LetDec rep)]
map_pes
ts' :: [Type]
ts' = [Type]
red_ts' [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
map_ts
kbody' :: KernelBody rep
kbody' = KernelBody rep
kbody {kernelBodyResult = red_res' ++ map_res}
Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
pes') (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall rep. Op rep -> Exp rep
Op (Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)))
-> Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel rep) rep -> Op rep
forall rep. HasSegOp rep => SegOp (SegOpLevel rep) rep -> Op rep
segOp (SegOp (SegOpLevel rep) rep -> Op rep)
-> SegOp (SegOpLevel rep) rep -> Op rep
forall a b. (a -> b) -> a -> b
$ SegOpLevel rep
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp (SegOpLevel rep) rep
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp lvl rep
SegRed SegOpLevel rep
lvl SegSpace
space [SegBinOp rep]
ops' [Type]
ts' KernelBody rep
kbody'
where
([PatElem (LetDec rep)]
red_pes, [PatElem (LetDec rep)]
map_pes) = Int
-> [PatElem (LetDec rep)]
-> ([PatElem (LetDec rep)], [PatElem (LetDec rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp rep] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) [PatElem (LetDec rep)]
pes
([Type]
red_ts, [Type]
map_ts) = Int -> [Type] -> ([Type], [Type])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp rep] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) [Type]
ts
([KernelResult]
red_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp rep] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody rep -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody rep
kbody
-- True when both operators carry the same 'segBinOpShape' and that
-- shape has rank greater than zero.  The second component of each
-- pair is ignored.
--
-- Inferred type: (SegBinOp rep, b) -> (SegBinOp rep, b) -> Bool
sameShape (op1, _) (op2, _) =
  segBinOpShape op1 == segBinOpShape op2
    && shapeRank (segBinOpShape op1) > 0
-- Collapse a group of operators (each paired with auxiliary data)
-- into a single operator by left-folding 'combine' over the group.
-- An empty group yields 'Nothing'; a singleton group is returned
-- unchanged.
--
-- Inferred type:
--   Buildable rep => [(SegBinOp rep, [a])] -> Maybe (SegBinOp rep, [a])
-- No local signature is written: it would have to repeat the
-- 'Buildable' constraint, so the type is left to inference.
combineOps [] = Nothing
combineOps (x : xs) = Just $ foldl' combine x xs
-- Merge two operators into one operator that computes both at once,
-- concatenating their auxiliary lists.
--
-- Each lambda's parameters are split at the number of neutral
-- elements into a first ("x") half and a second ("y") half.  The
-- merged lambda takes all x-parameters of both operators followed by
-- all y-parameters of both operators, so that the layout stays
-- "x-half, then y-half" for the combined operator.
--
-- Inferred type:
--   Buildable rep =>
--   (SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a])
-- ('mkBody' is what requires 'Buildable'.)
combine (op1, op1_aux) (op2, op2_aux) =
  let lam1 = segBinOpLambda op1
      lam2 = segBinOpLambda op2
      (op1_xparams, op1_yparams) =
        splitAt (length (segBinOpNeutral op1)) $ lambdaParams lam1
      (op2_xparams, op2_yparams) =
        splitAt (length (segBinOpNeutral op2)) $ lambdaParams lam2
      lam =
        Lambda
          { lambdaParams =
              op1_xparams
                ++ op2_xparams
                ++ op1_yparams
                ++ op2_yparams,
            lambdaReturnType = lambdaReturnType lam1 ++ lambdaReturnType lam2,
            -- Statements and results of both bodies are simply
            -- concatenated, first operator first.
            lambdaBody =
              mkBody (bodyStms (lambdaBody lam1) <> bodyStms (lambdaBody lam2)) $
                bodyResult (lambdaBody lam1) <> bodyResult (lambdaBody lam2)
          }
   in ( SegBinOp
          { segBinOpComm = segBinOpComm op1 <> segBinOpComm op2,
            segBinOpLambda = lam,
            segBinOpNeutral = segBinOpNeutral op1 ++ segBinOpNeutral op2,
            -- Only the first operator's shape is kept.
            -- NOTE(review): presumably callers only combine operators
            -- whose shapes are equal (cf. 'sameShape') -- confirm.
            segBinOpShape = segBinOpShape op1
          },
        op1_aux ++ op2_aux
      )
topDownSegOp SymbolTable rep
_ Pat (LetDec rep)
_ StmAux (ExpDec rep)
_ SegOp (SegOpLevel rep) rep
_ = Rule rep
forall rep. Rule rep
Skip
-- | Take a 'SegOp' apart into the pieces that are common to all of
-- its constructors.  The result is a tuple of:
--
-- 1. the result types of the operation;
--
-- 2. its 'KernelBody';
--
-- 3. the number of leading results that are not plain map results:
--    zero for 'SegMap', 'segBinOpResults' of the operators for
--    'SegScan' and 'SegRed', and the total number of 'histDest'
--    arrays for 'SegHist';
--
-- 4. a function that rebuilds the same kind of 'SegOp' (same level,
--    space and operators) from new result types and a new body.
segOpGuts ::
  SegOp (SegOpLevel rep) rep ->
  ( [Type],
    KernelBody rep,
    Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
  )
segOpGuts (SegMap lvl space kts body) =
  (kts, body, 0, SegMap lvl space)
segOpGuts (SegScan lvl space ops kts body) =
  (kts, body, segBinOpResults ops, SegScan lvl space ops)
segOpGuts (SegRed lvl space ops kts body) =
  (kts, body, segBinOpResults ops, SegRed lvl space ops)
segOpGuts (SegHist lvl space ops kts body) =
  (kts, body, sum $ map (length . histDest) ops, SegHist lvl space ops)
bottomUpSegOp ::
(Aliased rep, HasSegOp rep, BuilderOps rep) =>
(ST.SymbolTable rep, UT.UsageTable) ->
Pat (LetDec rep) ->
StmAux (ExpDec rep) ->
SegOp (SegOpLevel rep) rep ->
Rule rep
bottomUpSegOp :: forall rep.
(Aliased rep, HasSegOp rep, BuilderOps rep) =>
(SymbolTable rep, UsageTable)
-> Pat (LetDec rep)
-> StmAux (ExpDec rep)
-> SegOp (SegOpLevel rep) rep
-> Rule rep
bottomUpSegOp (SymbolTable rep
_vtable, UsageTable
used) (Pat [PatElem (LetDec rep)]
kpes) StmAux (ExpDec rep)
dec SegOp (SegOpLevel rep) rep
segop
| ([Int]
_, [PatElem (LetDec rep)]
kpes', [Type]
kts', [KernelResult]
kres') <- [(Int, PatElem (LetDec rep), Type, KernelResult)]
-> ([Int], [PatElem (LetDec rep)], [Type], [KernelResult])
forall a b c d. [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 ([(Int, PatElem (LetDec rep), Type, KernelResult)]
-> ([Int], [PatElem (LetDec rep)], [Type], [KernelResult]))
-> [(Int, PatElem (LetDec rep), Type, KernelResult)]
-> ([Int], [PatElem (LetDec rep)], [Type], [KernelResult])
forall a b. (a -> b) -> a -> b
$ ((Int, PatElem (LetDec rep), Type, KernelResult) -> Bool)
-> [(Int, PatElem (LetDec rep), Type, KernelResult)]
-> [(Int, PatElem (LetDec rep), Type, KernelResult)]
forall a. (a -> Bool) -> [a] -> [a]
filter (Int, PatElem (LetDec rep), Type, KernelResult) -> Bool
keep ([(Int, PatElem (LetDec rep), Type, KernelResult)]
-> [(Int, PatElem (LetDec rep), Type, KernelResult)])
-> [(Int, PatElem (LetDec rep), Type, KernelResult)]
-> [(Int, PatElem (LetDec rep), Type, KernelResult)]
forall a b. (a -> b) -> a -> b
$ [Int]
-> [PatElem (LetDec rep)]
-> [Type]
-> [KernelResult]
-> [(Int, PatElem (LetDec rep), Type, KernelResult)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [Int
0 ..] [PatElem (LetDec rep)]
kpes [Type]
kts [KernelResult]
kres,
[PatElem (LetDec rep)]
kpes' [PatElem (LetDec rep)] -> [PatElem (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
/= [PatElem (LetDec rep)]
kpes = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
KernelBody rep
kbody' <- Scope rep
-> RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep)
forall a. Scope rep -> RuleM rep a -> RuleM rep a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope rep
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep))
-> RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep)
forall a b. (a -> b) -> a -> b
$ Stms (Rep (RuleM rep))
-> [KernelResult] -> RuleM rep (KernelBody (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
Stms (Rep m) -> [KernelResult] -> m (KernelBody (Rep m))
mkKernelBodyM Stms rep
Stms (Rep (RuleM rep))
kstms [KernelResult]
kres'
Stm (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (RuleM rep)) -> RuleM rep ())
-> Stm (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> StmAux (ExpDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep))
-> Stm (Rep (RuleM rep))
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
kpes') StmAux (ExpDec rep)
StmAux (ExpDec (Rep (RuleM rep)))
dec (Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall rep. Op rep -> Exp rep
Op (Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)))
-> Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel rep) rep -> OpC rep rep
forall rep. HasSegOp rep => SegOp (SegOpLevel rep) rep -> Op rep
segOp (SegOp (SegOpLevel rep) rep -> OpC rep rep)
-> SegOp (SegOpLevel rep) rep -> OpC rep rep
forall a b. (a -> b) -> a -> b
$ [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
mk_segop [Type]
kts' KernelBody rep
kbody'
where
space :: SegSpace
space = SegOp (SegOpLevel rep) rep -> SegSpace
forall lvl rep. SegOp lvl rep -> SegSpace
segSpace SegOp (SegOpLevel rep) rep
segop
([Type]
kts, KernelBody BodyDec rep
_ Stms rep
kstms [KernelResult]
kres, Int
num_nonmap_results, [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
mk_segop) =
SegOp (SegOpLevel rep) rep
-> ([Type], KernelBody rep, Int,
[Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep)
forall rep.
SegOp (SegOpLevel rep) rep
-> ([Type], KernelBody rep, Int,
[Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep)
segOpGuts SegOp (SegOpLevel rep) rep
segop
keep :: (Int, PatElem (LetDec rep), Type, KernelResult) -> Bool
keep (Int
i, PatElem (LetDec rep)
pe, Type
_, KernelResult
_) =
Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
num_nonmap_results Bool -> Bool -> Bool
|| PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe VName -> UsageTable -> Bool
`UT.used` UsageTable
used
bottomUpSegOp (SymbolTable rep
vtable, UsageTable
_used) (Pat [PatElem (LetDec rep)]
kpes) StmAux (ExpDec rep)
dec SegOp (SegOpLevel rep) rep
segop = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
([PatElem (LetDec rep)]
kpes', [Type]
kts', [KernelResult]
kres', Stms rep
kstms') <-
Scope rep
-> RuleM
rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> RuleM
rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
forall a. Scope rep -> RuleM rep a -> RuleM rep a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope rep
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (RuleM
rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> RuleM
rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep))
-> RuleM
rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> RuleM
rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
forall a b. (a -> b) -> a -> b
$
(([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> Stm rep
-> RuleM
rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep))
-> ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> Stms rep
-> RuleM
rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
-> Stm rep
-> RuleM
rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
distribute ([PatElem (LetDec rep)]
kpes, [Type]
kts, [KernelResult]
kres, Stms rep
forall a. Monoid a => a
mempty) Stms rep
kstms
Bool -> RuleM rep () -> RuleM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([PatElem (LetDec rep)]
kpes' [PatElem (LetDec rep)] -> [PatElem (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElem (LetDec rep)]
kpes) RuleM rep ()
forall rep a. RuleM rep a
cannotSimplify
KernelBody rep
kbody' <-
Scope rep
-> RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep)
forall a. Scope rep -> RuleM rep a -> RuleM rep a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope rep
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep))
-> RuleM rep (KernelBody rep) -> RuleM rep (KernelBody rep)
forall a b. (a -> b) -> a -> b
$ Stms (Rep (RuleM rep))
-> [KernelResult] -> RuleM rep (KernelBody (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
Stms (Rep m) -> [KernelResult] -> m (KernelBody (Rep m))
mkKernelBodyM Stms rep
Stms (Rep (RuleM rep))
kstms' [KernelResult]
kres'
Stm (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (RuleM rep)) -> RuleM rep ())
-> Stm (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> StmAux (ExpDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep))
-> Stm (Rep (RuleM rep))
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
kpes') StmAux (ExpDec rep)
StmAux (ExpDec (Rep (RuleM rep)))
dec (Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> Stm (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall rep. Op rep -> Exp rep
Op (Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)))
-> Op (Rep (RuleM rep)) -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel rep) rep -> OpC rep rep
forall rep. HasSegOp rep => SegOp (SegOpLevel rep) rep -> Op rep
segOp (SegOp (SegOpLevel rep) rep -> OpC rep rep)
-> SegOp (SegOpLevel rep) rep -> OpC rep rep
forall a b. (a -> b) -> a -> b
$ [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
mk_segop [Type]
kts' KernelBody rep
kbody'
where
([Type]
kts, KernelBody BodyDec rep
_ Stms rep
kstms [KernelResult]
kres, Int
num_nonmap_results, [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
mk_segop) =
SegOp (SegOpLevel rep) rep
-> ([Type], KernelBody rep, Int,
[Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep)
forall rep.
SegOp (SegOpLevel rep) rep
-> ([Type], KernelBody rep, Int,
[Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep)
segOpGuts SegOp (SegOpLevel rep) rep
segop
free_in_kstms :: Names
free_in_kstms = (Stm rep -> Names) -> Stms rep -> Names
forall m a. Monoid m => (a -> m) -> Seq a -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap Stm rep -> Names
forall a. FreeIn a => a -> Names
freeIn Stms rep
kstms
space :: SegSpace
space = SegOp (SegOpLevel rep) rep -> SegSpace
forall lvl rep. SegOp lvl rep -> SegSpace
segSpace SegOp (SegOpLevel rep) rep
segop
-- Recognise a statement of the form @Index arr slice@ whose slice
-- begins by fixing one dimension to each index variable of the
-- enclosing 'space', in order.  On success, return the rest of the
-- slice (with that prefix dropped) together with the indexed array.
--
-- The match is only accepted when every name free in the array, the
-- remaining slice, and the statement's certificates has an entry in
-- 'vtable'.
--
-- Inferred type: Stm rep -> Maybe (Slice SubExp, VName)
sliceWithGtidsFixed stm
  | Let _ aux (BasicOp (Index arr slice)) <- stm,
    -- One 'DimFix' per index variable of the space.
    space_slice <- map (DimFix . Var . fst) $ unSegSpace space,
    space_slice `isPrefixOf` unSlice slice,
    remaining_slice <- Slice $ drop (length space_slice) (unSlice slice),
    all (isJust . flip ST.lookup vtable) $
      namesToList $
        freeIn arr <> freeIn remaining_slice <> freeIn (stmAuxCerts aux) =
      Just (remaining_slice, arr)
  | otherwise =
      Nothing
-- Process a single kernel-body statement against the accumulated
-- pattern elements, result types, kernel results and retained statements.
--
-- Inferred type (kept as a comment; it mentions the enclosing @rep@):
--
--   distribute ::
--     ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep) ->
--     Stm rep ->
--     RuleM rep ([PatElem (LetDec rep)], [Type], [KernelResult], Stms rep)
--
-- Relies on 'space', 'free_in_kstms', 'sliceWithGtidsFixed' and
-- 'isResult', all of which are bound outside this definition.
--
-- When the statement binds exactly one pattern element, is accepted by
-- 'sliceWithGtidsFixed', and 'isResult' finds the kernel result it
-- produces, two bindings are emitted in 'RuleM' instead:
--
--   1. a fresh @<name>_precopy@ variable bound to an 'Index' of the
--      array, using a full @0..d@ stride-1 slice for every dimension of
--      the segment space followed by the remaining slice;
--   2. the kernel pattern element's own name bound to a 'Replicate'
--      with an empty shape of that variable (presumably acting as a
--      copy, as the @_precopy@ suffix suggests).
--
-- The matched pattern element/type/result are dropped from the
-- accumulator, and the statement itself is retained only if its name
-- occurs in 'free_in_kstms'.  Any other statement is appended to the
-- retained statements unchanged.
distribute (cur_pes, cur_ts, cur_res, kept) stm
  | Let (Pat [pe]) _ _ <- stm,
    Just (Slice inner_slice, arr) <- sliceWithGtidsFixed stm,
    Just (kpe, pes, ts, res) <- isResult cur_pes cur_ts cur_res pe = do
      let -- Slice covering an entire dimension of size @d@.
          full_dim d = DimSlice (constant (0 :: Int64)) d (constant (1 :: Int64))
          hoisted_slice = Slice $ map full_dim (segSpaceDims space) <> inner_slice
          kpe_name = patElemName kpe
      tmp <- newVName $ baseString kpe_name <> "_precopy"
      letBindNames [tmp] $ BasicOp $ Index arr hoisted_slice
      letBindNames [kpe_name] $ BasicOp $ Replicate mempty $ Var tmp
      let kept'
            | patElemName pe `nameIn` free_in_kstms = kept <> oneStm stm
            | otherwise = kept
      pure (pes, ts, res, kept')
  | otherwise =
      pure (cur_pes, cur_ts, cur_res, kept <> oneStm stm)
-- Find the kernel result produced by the pattern element @pe@.
--
-- Inferred type (kept as a comment; it mentions the enclosing @rep@):
--
--   isResult ::
--     [PatElem (LetDec rep)] ->
--     [Type] ->
--     [KernelResult] ->
--     PatElem (LetDec rep) ->
--     Maybe (PatElem (LetDec rep), [PatElem (LetDec rep)], [Type], [KernelResult])
--
-- Relies on 'kpes' and 'num_nonmap_results', which are bound outside
-- this definition.
--
-- The three argument lists are zipped and split into the entries whose
-- result is a @Returns _ _ (Var v)@ with @v@ equal to the name of @pe@,
-- and all the others.  A result is produced only when exactly one entry
-- matches and the position of its pattern element in 'kpes' is at least
-- 'num_nonmap_results'.  The answer is that pattern element together
-- with the remaining (unzipped) pattern elements, types and results.
isResult cur_pes cur_ts cur_res pe
  | ([(kpe, _, _)], others) <- partition producesPe (zip3 cur_pes cur_ts cur_res),
    Just pos <- elemIndex kpe kpes,
    pos >= num_nonmap_results =
      let (rest_pes, rest_ts, rest_res) = unzip3 others
       in Just (kpe, rest_pes, rest_ts, rest_res)
  | otherwise = Nothing
  where
    -- Does this entry directly return the variable bound by @pe@?
    producesPe (_, _, res) =
      case res of
        Returns _ _ (Var v) -> v == patElemName pe
        _ -> False
-- Adjust a list of 'ExpReturns' according to the results of a kernel
-- body.  The kernel results are paired positionally with the given
-- returns (via 'zipWithM', so the shorter list determines the length):
-- for a 'WriteReturns' result the supplied return is discarded and
-- replaced by 'varReturns' of the array named in the result; every
-- other result leaves the supplied return untouched.
kernelBodyReturns ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  KernelBody somerep ->
  [ExpReturns] ->
  m [ExpReturns]
kernelBodyReturns kbody rets =
  zipWithM fixup (kernelBodyResult kbody) rets
  where
    fixup res ret =
      case res of
        WriteReturns _ arr _ -> varReturns arr
        _ -> pure ret
-- | Compute the 'ExpReturns' of a 'SegOp'.
--
-- For 'SegMap', 'SegRed' and 'SegScan' the result of 'opType' is
-- converted with 'extReturns' and then adjusted by 'kernelBodyReturns'
-- using the kernel body of the operation.  For 'SegHist' the returns
-- are the 'varReturns' of every 'histDest' array of every histogram
-- operation, in order.
segOpReturns ::
  (Mem rep inner, Monad m, HasScope rep m) =>
  SegOp lvl rep ->
  m [ExpReturns]
segOpReturns op =
  case op of
    SegMap _ _ _ kbody -> fromBody kbody
    SegRed _ _ _ _ kbody -> fromBody kbody
    SegScan _ _ _ _ kbody -> fromBody kbody
    SegHist _ _ hist_ops _ _ -> mapM varReturns $ concatMap histDest hist_ops
  where
    -- Shared by all constructors whose returns derive from the op type.
    fromBody kbody = do
      ts <- opType op
      kernelBodyReturns kbody $ extReturns ts