{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.IR.SegOp
( SegOp (..),
SegVirt (..),
SegSeqDims (..),
segLevel,
segBody,
segSpace,
typeCheckSegOp,
SegSpace (..),
scopeOfSegSpace,
segSpaceDims,
HistOp (..),
histType,
splitHistResults,
SegBinOp (..),
segBinOpResults,
segBinOpChunks,
KernelBody (..),
aliasAnalyseKernelBody,
consumedInKernelBody,
ResultManifest (..),
KernelResult (..),
kernelResultCerts,
kernelResultSubExp,
SplitOrdering (..),
SegOpMapper (..),
identitySegOpMapper,
mapSegOpM,
traverseSegOpStms,
simplifySegOp,
HasSegOp (..),
segOpRules,
segOpReturns,
)
where
import Control.Category
import Control.Monad.Identity hiding (mapM_)
import Control.Monad.Reader hiding (mapM_)
import Control.Monad.State.Strict
import Control.Monad.Writer hiding (mapM_)
import Data.Bifunctor (first)
import Data.Bitraversable
import Data.Foldable (traverse_)
import Data.List
( elemIndex,
foldl',
groupBy,
intersperse,
isPrefixOf,
partition,
)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Futhark.Analysis.Alias as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.IR
import Futhark.IR.Aliases
( Aliases,
removeLambdaAliases,
removeStmAliases,
)
import Futhark.IR.Mem
import Futhark.IR.Prop.Aliases
import qualified Futhark.IR.TypeCheck as TC
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty
( Pretty,
commasep,
parens,
ppr,
text,
(<+>),
(</>),
)
import qualified Futhark.Util.Pretty as PP
import Prelude hiding (id, (.))
-- | The ordering of a split, as carried by 'ConcatReturns': either
-- 'SplitContiguous', or 'SplitStrided' with a stride given as a
-- 'SubExp'.
-- NOTE(review): the exact memory layout implied by each ordering is
-- decided by code not visible here — confirm against the code generator.
data SplitOrdering
  = SplitContiguous
  | SplitStrided SubExp
  deriving (Eq, Ord, Show)
-- | Only a strided split can mention names: those free in its stride.
instance FreeIn SplitOrdering where
  freeIn' SplitContiguous = mempty
  freeIn' (SplitStrided stride) = freeIn' stride
-- | Substitution only affects the stride of a strided split.
instance Substitute SplitOrdering where
  substituteNames _ SplitContiguous =
    SplitContiguous
  substituteNames subst (SplitStrided stride) =
    SplitStrided $ substituteNames subst stride
-- | Renaming only affects the stride of a strided split.
instance Rename SplitOrdering where
  rename SplitContiguous =
    pure SplitContiguous
  rename (SplitStrided stride) =
    SplitStrided <$> rename stride
-- | A histogram operator.  Its result types are computed by 'histType',
-- and 'splitHistResults' uses 'histShape' and 'histDest' to split a flat
-- result list per operator.
data HistOp rep = HistOp
  { -- | Shape of the histogram.  'splitHistResults' expects as many
    -- index results per operator as the rank of this shape.
    histShape :: Shape,
    -- | NOTE(review): not used in this chunk; presumably a hint about
    -- expected write contention — confirm.
    histRaceFactor :: SubExp,
    -- | Destination arrays.  'splitHistResults' expects one value
    -- result per destination.
    histDest :: [VName],
    -- | Presumably the neutral elements of 'histOp' — TODO confirm.
    histNeutral :: [SubExp],
    -- | Additional operator shape, appended after 'histShape' by
    -- 'histType'.
    histOpShape :: Shape,
    -- | The combining operator.  Its return types determine 'histType'.
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)
-- | The types produced by a 'HistOp'.  Each return type of the
-- operator lambda is turned into an array of shape 'histShape'
-- followed by 'histOpShape'.
histType :: HistOp rep -> [Type]
histType op =
  map (`arrayOfShape` full_shape) $ lambdaReturnType $ histOp op
  where
    full_shape = histShape op <> histOpShape op
-- | Split a flat list of results into per-operator
-- @(indices, values)@ pairs.
--
-- All index results come first: each operator contributes as many
-- index results as the rank of its 'histShape'.  They are followed by
-- the value results, of which each operator contributes one per entry
-- in its 'histDest'.
splitHistResults :: [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ops res =
  let ranks = map (shapeRank . histShape) ops
      (idxs, vals) = splitAt (sum ranks) res
   in zip
        (chunks ranks idxs)
        (chunks (map (length . histDest) ops) vals)
-- | A binary operator with its neutral elements.  'segBinOpResults'
-- and 'segBinOpChunks' count one result per neutral element.
data SegBinOp rep = SegBinOp
  { -- | Commutativity of the operator.
    segBinOpComm :: Commutativity,
    -- | The operator itself.
    segBinOpLambda :: Lambda rep,
    -- | Neutral elements; their count is the number of results.
    segBinOpNeutral :: [SubExp],
    -- | NOTE(review): presumably an extra operator shape, analogous to
    -- 'histOpShape' for 'HistOp' — confirm.
    segBinOpShape :: Shape
  }
  deriving (Eq, Ord, Show)
-- | Total number of results produced by these operators, counting one
-- result per neutral element of each operator.
segBinOpResults :: [SegBinOp rep] -> Int
segBinOpResults = sum . map (length . segBinOpNeutral)
-- | Split a list into consecutive chunks, one per operator.  Each chunk
-- has as many elements as that operator has neutral elements
-- (consistent with 'segBinOpResults').
segBinOpChunks :: [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks = chunks . map (length . segBinOpNeutral)
-- | The body of a kernel.  It has a body decoration and statements like
-- an ordinary 'Body', but its results are 'KernelResult's rather than a
-- plain 'Result'.
data KernelBody rep = KernelBody
  { kernelBodyDec :: BodyDec rep,
    kernelBodyStms :: Stms rep,
    kernelBodyResult :: [KernelResult]
  }

-- The instances are derived standalone with an explicit @RepTypes rep@
-- context.  The fields mention the type family 'BodyDec', for which an
-- ordinary deriving clause cannot infer the instance context.
deriving instance RepTypes rep => Ord (KernelBody rep)

deriving instance RepTypes rep => Show (KernelBody rep)

deriving instance RepTypes rep => Eq (KernelBody rep)
-- | Annotation attached to a 'Returns' kernel result.
-- NOTE(review): judging by the constructor names, this records whether
-- the simplifier may touch the result and whether it is private to a
-- thread.  The consumers are not visible here — confirm.
data ResultManifest
  = ResultNoSimplify
  | ResultMaySimplify
  | ResultPrivate
  deriving (Eq, Show, Ord)
-- | One result produced by a 'KernelBody'.  Every constructor carries
-- 'Certs' (see 'kernelResultCerts') and exactly one returned value or
-- array (see 'kernelResultSubExp').
data KernelResult
  = -- | A manifest annotation, certificates, and the returned value.
    Returns ResultManifest Certs SubExp
  | WriteReturns
      Certs
      Shape -- presumably the shape of the written array — TODO confirm
      VName -- the array that is returned
      [(Slice SubExp, SubExp)] -- (slice, value) pairs
  | ConcatReturns
      Certs
      SplitOrdering -- how the chunks are ordered
      SubExp -- NOTE(review): presumably the total size — confirm
      SubExp -- NOTE(review): presumably the per-worker chunk size — confirm
      VName -- the array (chunk) that is returned
  | TileReturns
      Certs
      [(SubExp, SubExp)] -- a pair of sizes per dimension
      VName -- the tile array that is returned
  | RegTileReturns
      Certs
      [ ( SubExp,
          SubExp,
          SubExp
        )
      ] -- a dimension size plus tile sizes, per dimension — TODO confirm order
      VName -- the tile array that is returned
  deriving (Eq, Show, Ord)
-- | The certificates attached to a kernel result.  Every constructor
-- stores them in its first 'Certs' field.
kernelResultCerts :: KernelResult -> Certs
kernelResultCerts (Returns _ cs _) = cs
kernelResultCerts (WriteReturns cs _ _ _) = cs
kernelResultCerts (ConcatReturns cs _ _ _ _) = cs
kernelResultCerts (TileReturns cs _ _) = cs
kernelResultCerts (RegTileReturns cs _ _) = cs
-- | The value denoted by a kernel result.  For 'Returns' this is the
-- returned 'SubExp'.  For every other constructor it is the returned
-- array, wrapped in 'Var'.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp (Returns _ _ se) = se
kernelResultSubExp (WriteReturns _ _ arr _) = Var arr
kernelResultSubExp (ConcatReturns _ _ _ _ v) = Var v
kernelResultSubExp (TileReturns _ _ v) = Var v
kernelResultSubExp (RegTileReturns _ _ v) = Var v
-- | The free variables of a kernel result are the union of those of all
-- its fields.  The 'ResultManifest' of 'Returns' is skipped, since it
-- contains no names.
instance FreeIn KernelResult where
  freeIn' (Returns _ cs what) =
    freeIn' cs <> freeIn' what
  freeIn' (WriteReturns cs rws arr res) =
    freeIn' cs <> freeIn' rws <> freeIn' arr <> freeIn' res
  freeIn' (ConcatReturns cs o w per_thread_elems v) =
    freeIn' cs
      <> freeIn' o
      <> freeIn' w
      <> freeIn' per_thread_elems
      <> freeIn' v
  freeIn' (TileReturns cs dims v) =
    freeIn' cs <> freeIn' dims <> freeIn' v
  freeIn' (RegTileReturns cs dims_n_tiles v) =
    freeIn' cs <> freeIn' dims_n_tiles <> freeIn' v
-- | Free variables of a 'KernelBody'.  This is the union of the free
-- variables of its decoration, statements and results, with the names
-- bound by its own statements ('boundByStm') bound via 'fvBind'.
instance ASTRep rep => FreeIn (KernelBody rep) where
  freeIn' (KernelBody dec stms res) =
    fvBind bound_in_stms $ freeIn' dec <> freeIn' stms <> freeIn' res
    where
      bound_in_stms = foldMap boundByStm stms
-- | Apply a name substitution to the decoration, statements and results
-- of a 'KernelBody'.
instance ASTRep rep => Substitute (KernelBody rep) where
  substituteNames subst (KernelBody dec stms res) =
    KernelBody
      (substituteNames subst dec)
      (substituteNames subst stms)
      (substituteNames subst res)
-- | Apply a name substitution to every field of a kernel result.  The
-- 'ResultManifest' of 'Returns' is the exception: it contains no names
-- and is kept as-is.
instance Substitute KernelResult where
  substituteNames subst kres =
    case kres of
      Returns manifest cs se ->
        Returns manifest (s cs) (s se)
      WriteReturns cs rws arr res ->
        WriteReturns (s cs) (s rws) (s arr) (s res)
      ConcatReturns cs o w per_thread_elems v ->
        ConcatReturns (s cs) (s o) (s w) (s per_thread_elems) (s v)
      TileReturns cs dims v ->
        TileReturns (s cs) (s dims) (s v)
      RegTileReturns cs dims_n_tiles v ->
        RegTileReturns (s cs) (s dims_n_tiles) (s v)
    where
      -- Polymorphic shorthand; the signature is needed because it is
      -- used at several field types.
      s :: Substitute a => a -> a
      s = substituteNames subst
-- | Rename a 'KernelBody'.  The decoration is renamed first, then the
-- statements via 'renamingStms'.  The results are renamed inside the
-- continuation of 'renamingStms', presumably so that references to
-- statement-bound names pick up the fresh names.
instance ASTRep rep => Rename (KernelBody rep) where
  rename (KernelBody dec stms res) = do
    dec' <- rename dec
    renamingStms stms $ \stms' ->
      KernelBody dec' stms' <$> rename res
-- | Kernel results are renamed by plain substitution
-- ('substituteRename'), presumably because they bind no names of their
-- own.
instance Rename KernelResult where
  rename = substituteRename
-- | Run alias analysis on a 'KernelBody'.
--
-- This builds an ordinary 'Body' from the kernel body's decoration and
-- statements, with an empty result list, and analyses it with
-- 'Alias.analyseBody'.  The analysed decoration and statements are then
-- recombined with the original 'KernelResult's, which are left
-- unchanged.
aliasAnalyseKernelBody ::
  ( ASTRep rep,
    CanBeAliased (Op rep)
  ) =>
  AliasTable ->
  KernelBody rep ->
  KernelBody (Aliases rep)
aliasAnalyseKernelBody aliases (KernelBody dec stms res) =
  let Body dec' stms' _ = Alias.analyseBody aliases $ Body dec stms []
   in KernelBody dec' stms' res
-- | Strip alias annotations from a kernel body. The alias component of the
-- body decoration is dropped, and every statement has its aliases removed.
removeKernelBodyAliases ::
  CanBeAliased (Op rep) =>
  KernelBody (Aliases rep) ->
  KernelBody rep
removeKernelBodyAliases (KernelBody (_, dec) stms res) =
  KernelBody dec (removeStmAliases <$> stms) res
-- | Strip simplifier wisdom from a kernel body.
--
-- This delegates to 'removeBodyWisdom' on an ordinary 'Body' with an empty
-- result list, then re-attaches the original kernel results.
removeKernelBodyWisdom ::
  CanBeWise (Op rep) =>
  KernelBody (Wise rep) ->
  KernelBody rep
removeKernelBodyWisdom (KernelBody dec stms res) =
  KernelBody plain_dec plain_stms res
  where
    Body plain_dec plain_stms _ = removeBodyWisdom (Body dec stms [])
-- | The names consumed by a kernel body. These are the names consumed by its
-- statements, plus every destination array of a 'WriteReturns' result.
consumedInKernelBody ::
  Aliased rep =>
  KernelBody rep ->
  Names
consumedInKernelBody (KernelBody dec stms res) =
  consumedInBody (Body dec stms []) <> foldMap consumedByReturn res
  where
    -- Only 'WriteReturns' consumes anything: the array it writes into.
    consumedByReturn kres = case kres of
      WriteReturns _ _ dest _ -> oneName dest
      _ -> mempty
-- | Type-check a kernel body against the expected per-result types @ts@.
--
-- The body decoration is checked first. Then the destination arrays of all
-- 'WriteReturns' results (and their aliases) are consumed. Finally the
-- statements are checked, and, with their bindings in scope, the number and
-- type of every 'KernelResult'.
checkKernelBody ::
  TC.Checkable rep =>
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkKernelBody ts (KernelBody (_, dec) stms kres) = do
  TC.checkBodyDec dec
  mapM_ consumeKernelResult kres
  TC.checkStms stms $ do
    unless (length ts == length kres) $
      TC.bad $
        TC.TypeError $
          "Kernel return type is "
            ++ prettyTuple ts
            ++ ", but body returns "
            ++ show (length kres)
            ++ " values."
    zipWithM_ checkKernelResult kres ts
  where
    -- A 'WriteReturns' updates its destination array in place, so the array
    -- and everything aliasing it is consumed.
    consumeKernelResult (WriteReturns _ _ arr _) =
      TC.consume =<< TC.lookupAliases arr
    consumeKernelResult _ =
      pure ()

    checkKernelResult (Returns _ cs what) t = do
      TC.checkCerts cs
      TC.require [t] what
    checkKernelResult (WriteReturns cs shape arr res) t = do
      TC.checkCerts cs
      mapM_ (TC.require [Prim int64]) $ shapeDims shape
      arr_t <- lookupType arr
      -- NOTE(review): the destination-type check below sits inside the loop,
      -- so it is skipped entirely when @res@ is empty.
      forM_ res $ \(slice, e) -> do
        traverse_ (TC.require [Prim int64]) slice
        TC.require [t] e
        unless (arr_t == t `arrayOfShape` shape) $
          TC.bad $
            TC.TypeError $
              "WriteReturns returning "
                ++ pretty e
                ++ " of type "
                ++ pretty t
                ++ ", shape="
                ++ pretty shape
                ++ ", but destination array has type "
                ++ pretty arr_t
    checkKernelResult (ConcatReturns cs o w per_thread_elems v) t = do
      TC.checkCerts cs
      case o of
        SplitContiguous -> pure ()
        SplitStrided stride -> TC.require [Prim int64] stride
      TC.require [Prim int64] w
      TC.require [Prim int64] per_thread_elems
      vt <- lookupType v
      unless (vt == t `arrayOfRow` arraySize 0 vt) $
        TC.bad $
          TC.TypeError $
            "Invalid type for ConcatReturns " ++ pretty v
    checkKernelResult (TileReturns cs dims v) t = do
      TC.checkCerts cs
      forM_ dims $ \(dim, tile) -> do
        TC.require [Prim int64] dim
        TC.require [Prim int64] tile
      vt <- lookupType v
      -- The returned array must have one dimension per tile size.
      unless (vt == t `arrayOfShape` Shape (map snd dims)) $
        TC.bad $
          TC.TypeError $
            "Invalid type for TileReturns " ++ pretty v
    checkKernelResult (RegTileReturns cs dims_n_tiles arr) t = do
      TC.checkCerts cs
      mapM_ (TC.require [Prim int64]) dims
      mapM_ (TC.require [Prim int64]) blk_tiles
      mapM_ (TC.require [Prim int64]) reg_tiles

      arr_t <- lookupType arr
      -- The message names 'RegTileReturns' (previously it wrongly said
      -- 'TileReturns', pointing at the wrong construct).
      unless (arr_t == expected) $
        TC.bad . TC.TypeError $
          "Invalid type for RegTileReturns. Expected:\n "
            ++ pretty expected
            ++ ",\ngot:\n "
            ++ pretty arr_t
      where
        (dims, blk_tiles, reg_tiles) = unzip3 dims_n_tiles
        -- The array is expected to have all block-tile dimensions followed
        -- by all register-tile dimensions.
        expected :: Type
        expected = t `arrayOfShape` Shape (blk_tiles ++ reg_tiles)
-- | Collect metrics for every statement of a kernel body, in order.
kernelBodyMetrics :: OpMetrics (Op rep) => KernelBody rep -> MetricsM ()
kernelBodyMetrics kbody = mapM_ stmMetrics $ kernelBodyStms kbody
-- | A kernel body prints as its statements stacked vertically, followed by
-- @return {r1, r2, ...}@. The body decoration is not printed.
instance PrettyRep rep => Pretty (KernelBody rep) where
  ppr (KernelBody _ stms res) = stms_doc </> return_doc
    where
      stms_doc = PP.stack $ map ppr $ stmsToList stms
      return_doc = text "return" <+> PP.braces (PP.commasep (map ppr res))
-- | Pretty-print certificates as a prefix annotation. Returns nothing when
-- there are no certificates, and a single document otherwise.
certAnnots :: Certs -> [PP.Doc]
certAnnots cs = [ppr cs | cs /= mempty]
-- | Every kernel result prints as its certificate annotations (if any)
-- followed by a constructor-specific description.
instance Pretty KernelResult where
  ppr kres =
    case kres of
      Returns ResultNoSimplify cs what ->
        withCerts cs $ "returns (manifest)" <+> ppr what
      Returns ResultPrivate cs what ->
        withCerts cs $ "returns (private)" <+> ppr what
      Returns ResultMaySimplify cs what ->
        withCerts cs $ "returns" <+> ppr what
      WriteReturns cs shape arr res ->
        withCerts cs $
          ppr arr <+> PP.colon <+> ppr shape
            </> "with" <+> PP.apply (map ppWrite res)
      ConcatReturns cs SplitContiguous w per_thread_elems v ->
        withCerts cs $
          "concat"
            <> parens (commasep [ppr w, ppr per_thread_elems])
            <+> ppr v
      ConcatReturns cs (SplitStrided stride) w per_thread_elems v ->
        withCerts cs $
          "concat_strided"
            <> parens (commasep [ppr stride, ppr w, ppr per_thread_elems])
            <+> ppr v
      TileReturns cs dims v ->
        withCerts cs $
          "tile" <> parens (commasep $ map ppTileDim dims) <+> ppr v
      RegTileReturns cs dims_n_tiles v ->
        withCerts cs $
          "blkreg_tile"
            <> parens (commasep $ map ppRegTileDim dims_n_tiles)
            <+> ppr v
    where
      -- Prefix a result document with its certificate annotations.
      withCerts cs doc = PP.spread $ certAnnots cs ++ [doc]
      -- One in-place write: @slice = value@.
      ppWrite (slice, e) = ppr slice <+> text "=" <+> ppr e
      -- One tiled dimension: @dim / tile@.
      ppTileDim (dim, tile) = ppr dim <+> "/" <+> ppr tile
      -- One register-tiled dimension: @dim / (blk_tile * reg_tile)@.
      ppRegTileDim (dim, blk_tile, reg_tile) =
        ppr dim <+> "/" <+> parens (ppr blk_tile <+> "*" <+> ppr reg_tile)
-- | A list of integer dimension indices, wrapped in a newtype. It is carried
-- by 'SegNoVirtFull'.
newtype SegSeqDims = SegSeqDims
  { segSeqDims :: [Int]
  }
  deriving (Show, Ord, Eq)
-- | The virtualisation mode of a segmented operation.
--
-- Constructor order matters: the derived 'Ord' instance compares by it.
data SegVirt
  = SegVirt
  | SegNoVirt
  | -- | Carries a 'SegSeqDims' payload.
    SegNoVirtFull SegSeqDims
  deriving (Show, Ord, Eq)
-- | The index space of a segmented operation.
--
-- 'segFlat' names a flat index. 'unSegSpace' pairs each dimension's index
-- variable with its size. Field order matters for the derived 'Ord'.
data SegSpace = SegSpace
  { segFlat :: VName,
    unSegSpace :: [(VName, SubExp)]
  }
  deriving (Show, Ord, Eq)
-- | The sizes of all dimensions of a 'SegSpace', outermost first.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims (SegSpace _ space) = [size | (_, size) <- space]
-- | The scope introduced by a 'SegSpace'. The flat index and every
-- per-dimension index are bound as 64-bit index names.
scopeOfSegSpace :: SegSpace -> Scope rep
scopeOfSegSpace (SegSpace phys space) =
  M.fromList [(v, IndexName Int64) | v <- phys : map fst space]
-- | Type-check a 'SegSpace': every dimension size must be an @i64@.
checkSegSpace :: TC.Checkable rep => SegSpace -> TC.TypeM rep ()
checkSegSpace (SegSpace _ dims) =
  forM_ dims $ \(_, size) -> TC.require [Prim int64] size
-- | A segmented operation over the index space described by its
-- 'SegSpace'.
--
-- The @lvl@ parameter is a representation-specific level annotation;
-- it is validated by the caller-supplied function passed to
-- 'typeCheckSegOp'.
data SegOp lvl rep
  = -- | Every result is shaped by 'segResultShape' from its declared
    -- type and the corresponding body result.
    SegMap lvl SegSpace [Type] (KernelBody rep)
  | -- | The leading results are reductions whose result arrays drop
    -- the last dimension of the space; any remaining results are
    -- shaped as in 'SegMap'.
    SegRed lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | -- | Like 'SegRed', except that the scan results span every
    -- dimension of the space.
    SegScan lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | -- | The body returns, for each operator, bucket indices followed
    -- by values; the operators' destination arrays are consumed.
    SegHist lvl SegSpace [HistOp rep] [Type] (KernelBody rep)
  deriving (Eq, Ord, Show)
-- | The level annotation of a 'SegOp'.
segLevel :: SegOp lvl rep -> lvl
segLevel op =
  case op of
    SegMap lvl _ _ _ -> lvl
    SegRed lvl _ _ _ _ -> lvl
    SegScan lvl _ _ _ _ -> lvl
    SegHist lvl _ _ _ _ -> lvl
-- | The index space of a 'SegOp'.
segSpace :: SegOp lvl rep -> SegSpace
segSpace op =
  case op of
    SegMap _ space _ _ -> space
    SegRed _ space _ _ _ -> space
    SegScan _ space _ _ _ -> space
    SegHist _ space _ _ _ -> space
-- | The kernel body executed for every point of a 'SegOp''s space.
segBody :: SegOp lvl rep -> KernelBody rep
segBody (SegMap _ _ _ kbody) = kbody
segBody (SegRed _ _ _ _ kbody) = kbody
segBody (SegScan _ _ _ _ kbody) = kbody
segBody (SegHist _ _ _ _ kbody) = kbody
-- | The full type of the array produced by one 'KernelResult' of a
-- 'SegOp', given the type @t@ declared for that result.
segResultShape :: SegSpace -> Type -> KernelResult -> Type
segResultShape space t res =
  case res of
    Returns {} ->
      -- Wrap @t@ in one outer dimension per dimension of the space,
      -- with the first-listed dimension ending up outermost.
      foldr (\dim acc -> acc `arrayOfRow` dim) t (segSpaceDims space)
    WriteReturns _ shape _ _ ->
      t `arrayOfShape` shape
    ConcatReturns _ _ _ w _ ->
      t `arrayOfRow` w
    TileReturns _ dims _ ->
      t `arrayOfShape` Shape [dim | (dim, _) <- dims]
    RegTileReturns _ dims_n_tiles _ ->
      t `arrayOfShape` Shape [dim | (dim, _, _) <- dims_n_tiles]
-- | Combine the types of the leading (reduction or scan) results of a
-- 'SegOp' with the types of its remaining, mapped results.  The first
-- @length prefix@ declared types and body results belong to the prefix
-- and are skipped; the rest are shaped by 'segResultShape'.
segOpTypeWithPrefix :: SegSpace -> [Type] -> [Type] -> KernelBody rep -> [Type]
segOpTypeWithPrefix space prefix ts kbody =
  prefix ++ zipWith (segResultShape space) (drop n ts) (drop n results)
  where
    n = length prefix
    results = kernelBodyResult kbody

-- | The types of the values produced by a 'SegOp'.
segOpType :: SegOp lvl rep -> [Type]
segOpType (SegMap _ space ts kbody) =
  segOpTypeWithPrefix space [] ts kbody
segOpType (SegRed _ space reds ts kbody) =
  segOpTypeWithPrefix space red_ts ts kbody
  where
    -- The last dimension of the space is reduced away.
    segment_dims = init $ segSpaceDims space
    red_ts =
      [ t `arrayOfShape` (Shape segment_dims <> segBinOpShape op)
        | op <- reds,
          t <- lambdaReturnType (segBinOpLambda op)
      ]
segOpType (SegScan _ space scans ts kbody) =
  segOpTypeWithPrefix space scan_ts ts kbody
  where
    -- Scans keep every dimension of the space.
    scan_ts =
      [ t `arrayOfShape` (Shape (segSpaceDims space) <> segBinOpShape op)
        | op <- scans,
          t <- lambdaReturnType (segBinOpLambda op)
      ]
segOpType (SegHist _ space ops _ _) =
  [ t `arrayOfShape` (Shape segment_dims <> histShape op <> histOpShape op)
    | op <- ops,
      t <- lambdaReturnType (histOp op)
  ]
  where
    segment_dims = init $ segSpaceDims space
-- | The result types of a 'SegOp' are always statically known.
instance TypedOp (SegOp lvl rep) where
  opType op = pure $ staticShapes $ segOpType op
-- | No result of a 'SegOp' aliases anything.  A 'SegOp' consumes
-- whatever its body consumes; a 'SegHist' additionally consumes the
-- destination arrays of all its operators.
instance
  (ASTRep rep, Aliased rep, ASTConstraints lvl) =>
  AliasedOp (SegOp lvl rep)
  where
  opAliases op = [mempty | _ <- segOpType op]

  consumedInOp (SegHist _ _ ops _ kbody) =
    namesFromList (concatMap histDest ops) <> consumedInKernelBody kbody
  consumedInOp segop =
    -- Map, reduction and scan consume only what their body consumes.
    consumedInKernelBody $ segBody segop
-- | Type check a 'SegOp'.  The first argument checks the
-- representation-specific level annotation; it runs before anything
-- else is checked.
typeCheckSegOp ::
  TC.Checkable rep =>
  (lvl -> TC.TypeM rep ()) ->
  SegOp lvl (Aliases rep) ->
  TC.TypeM rep ()
typeCheckSegOp checkLvl segop = do
  checkLvl $ segLevel segop
  case segop of
    SegMap _ space ts kbody ->
      checkScanRed space [] ts kbody
    SegRed _ space reds ts kbody ->
      checkScanRed space (binOpParts reds) ts kbody
    SegScan _ space scans ts kbody ->
      checkScanRed space (binOpParts scans) ts kbody
    SegHist _ space ops ts kbody ->
      checkSegHist space ops ts kbody
  where
    -- Split each binary operator into the pieces 'checkScanRed' needs.
    binOpParts binops =
      [ (segBinOpLambda binop, segBinOpNeutral binop, segBinOpShape binop)
        | binop <- binops
      ]

-- | Type check the space, operators and body of a 'SegHist'.
checkSegHist ::
  TC.Checkable rep =>
  SegSpace ->
  [HistOp (Aliases rep)] ->
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkSegHist space ops ts kbody = do
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    nes_ts <- mapM checkOp ops

    checkKernelBody ts kbody

    -- The body must return, for each operator, one i64 bucket index
    -- per histogram dimension, followed by all the values to combine.
    let index_ts = concat [replicate (shapeRank (histShape op)) (Prim int64) | op <- ops]
        bucket_ret_t = index_ts ++ concat nes_ts
    unless (bucket_ret_t == ts) $
      TC.bad $
        TC.TypeError $
          "SegHist body has return type "
            ++ prettyTuple ts
            ++ " but should have type "
            ++ prettyTuple bucket_ret_t
  where
    segment_dims = init $ segSpaceDims space

    -- Check one histogram operator and return the types of the values
    -- the body must supply for it.
    checkOp (HistOp dest_shape rf dests nes shape lam) = do
      mapM_ (TC.require [Prim int64]) dest_shape
      TC.require [Prim int64] rf
      nes' <- mapM TC.checkArg nes
      mapM_ (TC.require [Prim int64]) $ shapeDims shape

      -- The operator is checked against the neutral elements with the
      -- leading @shapeRank shape@ dimensions stripped.
      let stripVecDims = stripArray $ shapeRank shape
          lam_args = [TC.noArgAliases (first stripVecDims arg) | arg <- nes' ++ nes']
      TC.checkLambda lam lam_args

      let nes_t = map TC.argType nes'
      unless (nes_t == lambdaReturnType lam) $
        TC.bad $
          TC.TypeError $
            "SegHist operator has return type "
              ++ prettyTuple (lambdaReturnType lam)
              ++ " but neutral element has type "
              ++ prettyTuple nes_t

      -- Each destination must be an array of its neutral element's type
      -- with the segment, histogram and operator dimensions prepended,
      -- and is consumed by the operation.
      let dest_shape' = Shape segment_dims <> dest_shape <> shape
      forM_ (zip nes_t dests) $ \(t, dest) -> do
        TC.requireI [t `arrayOfShape` dest_shape'] dest
        TC.consume =<< TC.lookupAliases dest

      pure [t `arrayOfShape` shape | t <- nes_t]
-- | Shared type checking for 'SegMap', 'SegRed' and 'SegScan'.
--
-- Each element of the operator list is an operator lambda, its neutral
-- elements, and the shape the operator is applied over.  The leading
-- declared result types @ts@ must be the neutral-element types with
-- that shape prepended, in operator order; the body is then checked
-- against all of @ts@.
checkScanRed ::
  TC.Checkable rep =>
  SegSpace ->
  [(Lambda (Aliases rep), [SubExp], Shape)] ->
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkScanRed space ops ts kbody = do
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    ne_ts <- forM ops $ \(lam, nes, shape) -> do
      mapM_ (TC.require [Prim int64]) $ shapeDims shape
      nes' <- mapM TC.checkArg nes

      -- The operator receives two sets of operands, each typed like
      -- the neutral elements.
      TC.checkLambda lam $ map TC.noArgAliases $ nes' ++ nes'
      let nes_t = map TC.argType nes'

      -- Report the offending types, as the SegHist check does, rather
      -- than a bare "wrong type" message.
      unless (lambdaReturnType lam == nes_t) $
        TC.bad $
          TC.TypeError $
            "wrong type for operator or neutral elements: operator has return type "
              ++ prettyTuple (lambdaReturnType lam)
              ++ " but neutral elements have type "
              ++ prettyTuple nes_t
              ++ "."

      pure $ map (`arrayOfShape` shape) nes_t

    let expecting = concat ne_ts
        got = take (length expecting) ts
    unless (expecting == got) $
      TC.bad $
        TC.TypeError $
          "Wrong return for body (does not match neutral elements; expected "
            ++ pretty expecting
            ++ "; found "
            ++ pretty got
            ++ ")"

    checkKernelBody ts kbody
-- | A record of monadic transformations, one per kind of component found
-- in a 'SegOp'.  A mapper may change the representation of the operator
-- lambdas and kernel body from @frep@ to @trep@; the level type @lvl@ is
-- kept.
data SegOpMapper lvl frep trep m = SegOpMapper
  { -- | Applied to the 'SubExp's reached by the traversal (for example
    -- neutral elements, shape dimensions and 'SegSpace' bounds).
    mapOnSegOpSubExp :: SubExp -> m SubExp,
    -- | Applied to operator lambdas (reduction, scan and histogram
    -- operators).
    mapOnSegOpLambda :: Lambda frep -> m (Lambda trep),
    -- | Applied to the kernel body.
    mapOnSegOpBody :: KernelBody frep -> m (KernelBody trep),
    -- | Applied to the bare 'VName's reached by the traversal.
    mapOnSegOpVName :: VName -> m VName,
    -- | Applied to the level annotation.
    mapOnSegOpLevel :: lvl -> m lvl
  }
-- | A mapper that hands every component back unchanged (via 'pure').
-- Intended as a base to be overridden with record-update syntax.
identitySegOpMapper :: Monad m => SegOpMapper lvl rep rep m
identitySegOpMapper = SegOpMapper pure pure pure pure pure
-- | Apply a 'SegOpMapper' to a 'SegSpace'.  The first 'VName' of the
-- space (@phys@) goes through 'mapOnSegOpVName'.  Each @(name, bound)@
-- dimension pair has its name mapped with 'mapOnSegOpVName' and its bound
-- with 'mapOnSegOpSubExp'.  Effects run in that order: @phys@ first, then
-- the dimensions left to right (name before bound).
mapOnSegSpace ::
  Monad f => SegOpMapper lvl frep trep f -> SegSpace -> f SegSpace
mapOnSegSpace tv (SegSpace phys dims) = do
  phys' <- mapOnSegOpVName tv phys
  dims' <- mapM onDim dims
  pure $ SegSpace phys' dims'
  where
    onDim (name, bound) =
      (,) <$> mapOnSegOpVName tv name <*> mapOnSegOpSubExp tv bound
-- | Apply a 'SegOpMapper' to a 'SegBinOp'.  The commutativity flag is
-- kept as it is.  The operator lambda, then the neutral elements, then the
-- dimensions of the operator shape are transformed, in that order.
mapSegBinOp ::
  Monad m =>
  SegOpMapper lvl frep trep m ->
  SegBinOp frep ->
  m (SegBinOp trep)
mapSegBinOp tv (SegBinOp comm red_op nes shape) = do
  red_op' <- mapOnSegOpLambda tv red_op
  nes' <- mapM onSubExp nes
  shape' <- traverse onSubExp shape
  pure $ SegBinOp comm red_op' nes' shape'
  where
    onSubExp = mapOnSegOpSubExp tv
-- | Apply a 'SegOpMapper' to the given 'SegOp'.
--
-- For every constructor the components are visited in the same order:
-- level, 'SegSpace', operators (if any), result types, body.
--
-- Result types are transformed with 'mapOnSegOpType' for every
-- constructor, not only for 'SegMap'.  The alternative, 'mapOnType', only
-- receives the 'SubExp' function.  It therefore cannot pass the 'VName'
-- of an 'Acc' result type through 'mapOnSegOpVName', so substitution,
-- renaming and free-variable traversals built on this function would miss
-- that name in 'SegRed', 'SegScan' and 'SegHist' result types.
mapSegOpM ::
  Monad m =>
  SegOpMapper lvl frep trep m ->
  SegOp lvl frep ->
  m (SegOp lvl trep)
mapSegOpM tv (SegMap lvl space ts body) =
  SegMap
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegRed lvl space reds ts body) =
  SegRed
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapSegBinOp tv) reds
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegScan lvl space scans ts body) =
  SegScan
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapSegBinOp tv) scans
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegHist lvl space ops ts body) =
  SegHist
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM onHistOp ops
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
  where
    -- Every field of a 'HistOp' is transformed in declaration order:
    -- width shape, @rf@, destination arrays, neutral elements, operator
    -- shape, and finally the operator lambda.
    onHistOp (HistOp w rf arrs nes shape op) =
      HistOp
        <$> mapM (mapOnSegOpSubExp tv) w
        <*> mapOnSegOpSubExp tv rf
        <*> mapM (mapOnSegOpVName tv) arrs
        <*> mapM (mapOnSegOpSubExp tv) nes
        <*> (Shape <$> mapM (mapOnSegOpSubExp tv) (shapeDims shape))
        <*> mapOnSegOpLambda tv op
-- | Apply a 'SegOpMapper' to a 'Type'.
--
-- * Primitive and memory types are returned unchanged.
-- * For an array, every dimension of its shape goes through
--   'mapOnSegOpSubExp'; element type and uniqueness are kept.
-- * For an accumulator, the accumulator 'VName' goes through
--   'mapOnSegOpVName' and the @ispace@ dimensions go through
--   'mapOnSegOpSubExp'.  Each element type is transformed with
--   'bitraverse', mapping its shape with 'mapOnSegOpSubExp' and keeping
--   its uniqueness.  The accumulator's own uniqueness is kept.
mapOnSegOpType ::
  Monad m =>
  SegOpMapper lvl frep trep m ->
  Type ->
  m Type
mapOnSegOpType tv t =
  case t of
    Prim {} -> pure t
    Mem {} -> pure t
    Array et shape u -> do
      shape' <- traverse onSubExp shape
      pure $ Array et shape' u
    Acc acc ispace ts u -> do
      acc' <- mapOnSegOpVName tv acc
      ispace' <- traverse onSubExp ispace
      ts' <- traverse (bitraverse (traverse onSubExp) pure) ts
      pure $ Acc acc' ispace' ts' u
  where
    onSubExp = mapOnSegOpSubExp tv
-- | A statement traverser for 'SegOp's.
--
-- The traversal function is applied to the statements of the kernel body
-- with the scope bound by the operator's 'SegSpace'.  It is also applied
-- to the statements of every operator lambda.  There, the 'SegSpace'
-- scope is combined (on the left of '<>') with the scope supplied for the
-- lambda.  Kernel results are left untouched.
traverseSegOpStms :: Monad m => OpStmsTraverser m (SegOp lvl rep) rep
traverseSegOpStms f segop =
  mapSegOpM
    identitySegOpMapper
      { mapOnSegOpLambda =
          traverseLambdaStms (\inner_scope -> f (space_scope <> inner_scope)),
        mapOnSegOpBody = \(KernelBody dec stms res) ->
          KernelBody dec <$> f space_scope stms <*> pure res
      }
    segop
  where
    space_scope = scopeOfSegSpace (segSpace segop)
-- | Substitution in a 'SegOp' applies 'substituteNames' to every
-- component visited by 'mapSegOpM': sub-expressions, operator lambdas,
-- the kernel body, bare 'VName's and the level.
instance
  (ASTRep rep, Substitute lvl) =>
  Substitute (SegOp lvl rep)
  where
  substituteNames subst segop =
    runIdentity $ mapSegOpM substituter segop
    where
      -- An explicit signature keeps this helper polymorphic, which is
      -- needed because MonoLocalBinds (implied by TypeFamilies) would not
      -- otherwise generalise a local binding that mentions @subst@.
      sub :: Substitute a => a -> Identity a
      sub x = pure (substituteNames subst x)

      substituter =
        SegOpMapper
          { mapOnSegOpSubExp = sub,
            mapOnSegOpLambda = sub,
            mapOnSegOpBody = sub,
            mapOnSegOpVName = sub,
            mapOnSegOpLevel = sub
          }
-- | Renaming a 'SegOp' runs the 'mapSegOpM' traversal, with 'rename' on
-- every component, inside 'renameBound' over the names that make up the
-- scope of the operator's 'SegSpace'.
instance (ASTRep rep, ASTConstraints lvl) => Rename (SegOp lvl rep) where
  rename op = renameBound bound_names (mapSegOpM renamer op)
    where
      bound_names = M.keys $ scopeOfSegSpace $ segSpace op
      renamer =
        SegOpMapper
          { mapOnSegOpSubExp = rename,
            mapOnSegOpLambda = rename,
            mapOnSegOpBody = rename,
            mapOnSegOpVName = rename,
            mapOnSegOpLevel = rename
          }
-- | The free variables of a 'SegOp' are the union of the free variables
-- of all its components (subexpressions, lambdas, kernel body, names and
-- level), excluding the index variables bound by its 'SegSpace'.
instance
  (ASTRep rep, FreeIn (LParamInfo rep), FreeIn lvl) =>
  FreeIn (SegOp lvl rep)
  where
  freeIn' segop = fvBind boundBySpace collected
    where
      -- Names bound by the segmented space; these are not free.
      boundBySpace = namesFromList $ M.keys $ scopeOfSegSpace $ segSpace segop
      -- Traverse the operation purely for its side effect on the state,
      -- which accumulates the free variables of every component.
      collected = execState (mapSegOpM collector segop) mempty
      -- Record the free variables of a component and return it unchanged.
      note f x = x <$ modify (<> f x)
      collector =
        SegOpMapper
          { mapOnSegOpSubExp = note freeIn',
            mapOnSegOpLambda = note freeIn',
            mapOnSegOpBody = note freeIn',
            mapOnSegOpVName = note freeIn',
            mapOnSegOpLevel = note freeIn'
          }
-- | Metrics for a 'SegOp' are recorded under a label naming the kind of
-- operation. Reductions, scans and histograms first record the metrics of
-- their operator lambdas, then those of the kernel body; a 'SegMap' only
-- has a kernel body.
instance OpMetrics (Op rep) => OpMetrics (SegOp lvl rep) where
  opMetrics segop =
    case segop of
      SegMap _ _ _ kbody ->
        inside "SegMap" $ kernelBodyMetrics kbody
      SegRed _ _ reds _ kbody ->
        inside "SegRed" $ operatorsThen (map segBinOpLambda reds) kbody
      SegScan _ _ scans _ kbody ->
        inside "SegScan" $ operatorsThen (map segBinOpLambda scans) kbody
      SegHist _ _ hists _ kbody ->
        inside "SegHist" $ operatorsThen (map histOp hists) kbody
    where
      -- Metrics of each operator lambda, followed by the kernel body.
      operatorsThen lams kbody = do
        traverse_ lambdaMetrics lams
        kernelBodyMetrics kbody
-- | A segmented space is printed as a parenthesised, comma-separated list
-- of @index < bound@ pairs, followed by the physical index name as
-- @(~phys)@.
instance Pretty SegSpace where
  ppr (SegSpace phys dims) =
    parens (commasep bounds) <+> parens (text "~" <> ppr phys)
    where
      bounds = [ppr gtid <+> "<" <+> ppr dim | (gtid, dim) <- dims]
-- | A segmented binary operator is printed as its braced neutral elements,
-- its shape, an optional @commutative@ marker and finally its lambda.
instance PrettyRep rep => Pretty (SegBinOp rep) where
  ppr (SegBinOp comm lam nes shape) =
    neutral <> PP.comma </> ppr shape <> PP.comma </> commDoc <> ppr lam
    where
      neutral = PP.braces $ PP.commasep $ ppr <$> nes
      -- Only commutative operators are marked; the trailing space separates
      -- the marker from the lambda.
      commDoc
        | Commutative <- comm = text "commutative "
        | otherwise = mempty
-- | Pretty-print a segmented operation. Each is rendered as a keyword with
-- its level, the segmented space and, except for @segmap@, a parenthesised
-- list of operators. These are followed by the result types and the kernel
-- body in braces.
instance (PrettyRep rep, PP.Pretty lvl) => PP.Pretty (SegOp lvl rep) where
  ppr segop =
    case segop of
      SegMap lvl space ts body ->
        keyword "segmap" lvl
          </> PP.align (ppr space)
          <+> typedBody ts body
      SegRed lvl space reds ts body ->
        withOperators "segred" lvl space (map ppr reds) ts body
      SegScan lvl space scans ts body ->
        withOperators "segscan" lvl space (map ppr scans) ts body
      SegHist lvl space ops ts body ->
        withOperators "seghist" lvl space (map ppHistOp ops) ts body
    where
      -- The operation keyword immediately followed by its level.
      keyword kw lvl = text kw <> ppr lvl
      -- Result types and the braced kernel body.
      typedBody ts body =
        PP.colon <+> ppTuple' ts <+> PP.nestedBlock "{" "}" (ppr body)
      -- Layout shared by all operations that carry a list of operators.
      withOperators kw lvl space ops ts body =
        keyword kw lvl
          </> PP.align (ppr space)
          </> PP.parens (mconcat $ intersperse (PP.comma <> PP.line) ops)
          </> typedBody ts body
      -- A histogram operator: width, race factor, destinations, neutral
      -- elements, shape and lambda.
      ppHistOp (HistOp width rf dests nes shape lam) =
        ppr width <> PP.comma <+> ppr rf <> PP.comma
          </> braced dests <> PP.comma
          </> braced nes <> PP.comma
          </> ppr shape <> PP.comma
          </> ppr lam
      braced xs = PP.braces $ PP.commasep $ map ppr xs
-- | Alias analysis of a 'SegOp' is delegated to its lambdas and kernel
-- body. Subexpressions, names and the level are passed through unchanged,
-- both when adding and when removing alias information.
instance
  ( ASTRep rep,
    ASTRep (Aliases rep),
    CanBeAliased (Op rep),
    ASTConstraints lvl
  ) =>
  CanBeAliased (SegOp lvl rep)
  where
  type OpWithAliases (SegOp lvl rep) = SegOp lvl (Aliases rep)

  addOpAliases aliases segop = runIdentity $ mapSegOpM annotate segop
    where
      annotate =
        SegOpMapper
          { mapOnSegOpSubExp = pure,
            mapOnSegOpLambda = pure . Alias.analyseLambda aliases,
            mapOnSegOpBody = pure . aliasAnalyseKernelBody aliases,
            mapOnSegOpVName = pure,
            mapOnSegOpLevel = pure
          }

  removeOpAliases segop = runIdentity $ mapSegOpM strip segop
    where
      strip =
        SegOpMapper
          { mapOnSegOpSubExp = pure,
            mapOnSegOpLambda = pure . removeLambdaAliases,
            mapOnSegOpBody = pure . removeKernelBodyAliases,
            mapOnSegOpVName = pure,
            mapOnSegOpLevel = pure
          }
-- | Annotate a kernel body with wisdom. Its statements are informed, and
-- the body is rebuilt with 'mkWiseKernelBody'; the body decoration and
-- the results are reused unchanged.
informKernelBody :: Informing rep => KernelBody rep -> KernelBody (Wise rep)
informKernelBody (KernelBody bodydec bodystms bodyres) =
  mkWiseKernelBody bodydec wiseStms bodyres
  where
    wiseStms = informStms bodystms
-- | Wisdom is added to or removed from a 'SegOp' by transforming its
-- lambdas and kernel body. Subexpressions, names and the level are passed
-- through unchanged.
instance
  (CanBeWise (Op rep), ASTRep rep, ASTConstraints lvl) =>
  CanBeWise (SegOp lvl rep)
  where
  type OpWithWisdom (SegOp lvl rep) = SegOp lvl (Wise rep)

  removeOpWisdom segop = runIdentity $ mapSegOpM forget segop
    where
      forget =
        SegOpMapper
          { mapOnSegOpSubExp = pure,
            mapOnSegOpLambda = pure . removeLambdaWisdom,
            mapOnSegOpBody = pure . removeKernelBodyWisdom,
            mapOnSegOpVName = pure,
            mapOnSegOpLevel = pure
          }

  addOpWisdom segop = runIdentity $ mapSegOpM inform segop
    where
      inform =
        SegOpMapper
          { mapOnSegOpSubExp = pure,
            mapOnSegOpLambda = pure . informLambda,
            mapOnSegOpBody = pure . informKernelBody,
            mapOnSegOpVName = pure,
            mapOnSegOpLevel = pure
          }
instance ASTRep rep => ST.IndexOp (SegOp lvl rep) where
indexOp :: SymbolTable rep
-> Int -> SegOp lvl rep -> [TPrimExp Int64 VName] -> Maybe Indexed
indexOp SymbolTable rep
vtable Int
k (SegMap lvl
_ SegSpace
space [Type]
_ KernelBody rep
kbody) [TPrimExp Int64 VName]
is = do
Returns ResultManifest
ResultMaySimplify Certs
_ SubExp
se <- Int -> [KernelResult] -> Maybe KernelResult
forall int a. Integral int => int -> [a] -> Maybe a
maybeNth Int
k ([KernelResult] -> Maybe KernelResult)
-> [KernelResult] -> Maybe KernelResult
forall a b. (a -> b) -> a -> b
$ KernelBody rep -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody rep
kbody
Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
gtids Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= [TPrimExp Int64 VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 VName]
is
let idx_table :: Map VName Indexed
idx_table = [(VName, Indexed)] -> Map VName Indexed
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, Indexed)] -> Map VName Indexed)
-> [(VName, Indexed)] -> Map VName Indexed
forall a b. (a -> b) -> a -> b
$ [VName] -> [Indexed] -> [(VName, Indexed)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
gtids ([Indexed] -> [(VName, Indexed)])
-> [Indexed] -> [(VName, Indexed)]
forall a b. (a -> b) -> a -> b
$ (TPrimExp Int64 VName -> Indexed)
-> [TPrimExp Int64 VName] -> [Indexed]
forall a b. (a -> b) -> [a] -> [b]
map (Certs -> PrimExp VName -> Indexed
ST.Indexed Certs
forall a. Monoid a => a
mempty (PrimExp VName -> Indexed)
-> (TPrimExp Int64 VName -> PrimExp VName)
-> TPrimExp Int64 VName
-> Indexed
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. TPrimExp Int64 VName -> PrimExp VName
forall t v. TPrimExp t v -> PrimExp v
untyped) [TPrimExp Int64 VName]
is
idx_table' :: Map VName Indexed
idx_table' = (Map VName Indexed -> Stm rep -> Map VName Indexed)
-> Map VName Indexed -> Seq (Stm rep) -> Map VName Indexed
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' Map VName Indexed -> Stm rep -> Map VName Indexed
expandIndexedTable Map VName Indexed
idx_table (Seq (Stm rep) -> Map VName Indexed)
-> Seq (Stm rep) -> Map VName Indexed
forall a b. (a -> b) -> a -> b
$ KernelBody rep -> Seq (Stm rep)
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody rep
kbody
case SubExp
se of
Var VName
v -> VName -> Map VName Indexed -> Maybe Indexed
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v Map VName Indexed
idx_table'
SubExp
_ -> Maybe Indexed
forall a. Maybe a
Nothing
where
([VName]
gtids, [SubExp]
_) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
excess_is :: [TPrimExp Int64 VName]
excess_is = Int -> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. Int -> [a] -> [a]
drop ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
gtids) [TPrimExp Int64 VName]
is
expandIndexedTable :: Map VName Indexed -> Stm rep -> Map VName Indexed
expandIndexedTable Map VName Indexed
table Stm rep
stm
| [VName
v] <- Pat (LetDec rep) -> [VName]
forall dec. Pat dec -> [VName]
patNames (Pat (LetDec rep) -> [VName]) -> Pat (LetDec rep) -> [VName]
forall a b. (a -> b) -> a -> b
$ Stm rep -> Pat (LetDec rep)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm rep
stm,
Just (PrimExp VName
pe, Certs
cs) <-
WriterT Certs Maybe (PrimExp VName) -> Maybe (PrimExp VName, Certs)
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT Certs Maybe (PrimExp VName)
-> Maybe (PrimExp VName, Certs))
-> WriterT Certs Maybe (PrimExp VName)
-> Maybe (PrimExp VName, Certs)
forall a b. (a -> b) -> a -> b
$ (VName -> WriterT Certs Maybe (PrimExp VName))
-> Exp rep -> WriterT Certs Maybe (PrimExp VName)
forall (m :: * -> *) rep v.
(MonadFail m, RepTypes rep) =>
(VName -> m (PrimExp v)) -> Exp rep -> m (PrimExp v)
primExpFromExp (Map VName Indexed -> VName -> WriterT Certs Maybe (PrimExp VName)
asPrimExp Map VName Indexed
table) (Exp rep -> WriterT Certs Maybe (PrimExp VName))
-> Exp rep -> WriterT Certs Maybe (PrimExp VName)
forall a b. (a -> b) -> a -> b
$ Stm rep -> Exp rep
forall rep. Stm rep -> Exp rep
stmExp Stm rep
stm =
VName -> Indexed -> Map VName Indexed -> Map VName Indexed
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
v (Certs -> PrimExp VName -> Indexed
ST.Indexed (Stm rep -> Certs
forall rep. Stm rep -> Certs
stmCerts Stm rep
stm Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> Certs
cs) PrimExp VName
pe) Map VName Indexed
table
| [VName
v] <- Pat (LetDec rep) -> [VName]
forall dec. Pat dec -> [VName]
patNames (Pat (LetDec rep) -> [VName]) -> Pat (LetDec rep) -> [VName]
forall a b. (a -> b) -> a -> b
$ Stm rep -> Pat (LetDec rep)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm rep
stm,
BasicOp (Index VName
arr Slice SubExp
slice) <- Stm rep -> Exp rep
forall rep. Stm rep -> Exp rep
stmExp Stm rep
stm,
[SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
slice) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [TPrimExp Int64 VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 VName]
excess_is,
VName
arr VName -> SymbolTable rep -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.available` SymbolTable rep
vtable,
Just (Slice (PrimExp VName)
slice', Certs
cs) <- Map VName Indexed
-> Slice SubExp -> Maybe (Slice (PrimExp VName), Certs)
asPrimExpSlice Map VName Indexed
table Slice SubExp
slice =
let idx :: Indexed
idx =
Certs -> VName -> [TPrimExp Int64 VName] -> Indexed
ST.IndexedArray
(Stm rep -> Certs
forall rep. Stm rep -> Certs
stmCerts Stm rep
stm Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> Certs
cs)
VName
arr
(Slice (TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall d. Num d => Slice d -> [d] -> [d]
fixSlice ((PrimExp VName -> TPrimExp Int64 VName)
-> Slice (PrimExp VName) -> Slice (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap PrimExp VName -> TPrimExp Int64 VName
forall v. PrimExp v -> TPrimExp Int64 v
isInt64 Slice (PrimExp VName)
slice') [TPrimExp Int64 VName]
excess_is)
in VName -> Indexed -> Map VName Indexed -> Map VName Indexed
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
v Indexed
idx Map VName Indexed
table
| Bool
otherwise =
Map VName Indexed
table
asPrimExpSlice :: Map VName Indexed
-> Slice SubExp -> Maybe (Slice (PrimExp VName), Certs)
asPrimExpSlice Map VName Indexed
table =
WriterT Certs Maybe (Slice (PrimExp VName))
-> Maybe (Slice (PrimExp VName), Certs)
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT Certs Maybe (Slice (PrimExp VName))
-> Maybe (Slice (PrimExp VName), Certs))
-> (Slice SubExp -> WriterT Certs Maybe (Slice (PrimExp VName)))
-> Slice SubExp
-> Maybe (Slice (PrimExp VName), Certs)
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. (SubExp -> WriterT Certs Maybe (PrimExp VName))
-> Slice SubExp -> WriterT Certs Maybe (Slice (PrimExp VName))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse ((VName -> WriterT Certs Maybe (PrimExp VName))
-> SubExp -> WriterT Certs Maybe (PrimExp VName)
forall (m :: * -> *) v.
Applicative m =>
(VName -> m (PrimExp v)) -> SubExp -> m (PrimExp v)
primExpFromSubExpM (Map VName Indexed -> VName -> WriterT Certs Maybe (PrimExp VName)
asPrimExp Map VName Indexed
table))
asPrimExp :: Map VName Indexed -> VName -> WriterT Certs Maybe (PrimExp VName)
asPrimExp Map VName Indexed
table VName
v
| Just (ST.Indexed Certs
cs PrimExp VName
e) <- VName -> Map VName Indexed -> Maybe Indexed
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v Map VName Indexed
table = Certs -> WriterT Certs Maybe ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell Certs
cs WriterT Certs Maybe ()
-> WriterT Certs Maybe (PrimExp VName)
-> WriterT Certs Maybe (PrimExp VName)
forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> PrimExp VName -> WriterT Certs Maybe (PrimExp VName)
forall (f :: * -> *) a. Applicative f => a -> f a
pure PrimExp VName
e
| Just (Prim PrimType
pt) <- VName -> SymbolTable rep -> Maybe Type
forall rep. ASTRep rep => VName -> SymbolTable rep -> Maybe Type
ST.lookupType VName
v SymbolTable rep
vtable =
PrimExp VName -> WriterT Certs Maybe (PrimExp VName)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (PrimExp VName -> WriterT Certs Maybe (PrimExp VName))
-> PrimExp VName -> WriterT Certs Maybe (PrimExp VName)
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> PrimExp VName
forall v. v -> PrimType -> PrimExp v
LeafExp VName
v PrimType
pt
| Bool
otherwise = Maybe (PrimExp VName) -> WriterT Certs Maybe (PrimExp VName)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift Maybe (PrimExp VName)
forall a. Maybe a
Nothing
indexOp SymbolTable rep
_ Int
_ SegOp lvl rep
_ [TPrimExp Int64 VName]
_ = Maybe Indexed
forall a. Maybe a
Nothing
-- | 'SegOp's are reported as never cheap ('cheapOp') and always safe
-- ('safeOp'), regardless of the particular operation.
instance
  (ASTRep rep, ASTConstraints lvl) =>
  IsOp (SegOp lvl rep)
  where
  safeOp = const True
  cheapOp = const False
-- | 'SplitContiguous' is returned as-is; for 'SplitStrided' the stride is
-- simplified.
instance Engine.Simplifiable SplitOrdering where
  simplify ordering =
    case ordering of
      SplitContiguous -> pure SplitContiguous
      SplitStrided stride -> fmap SplitStrided (Engine.simplify stride)
-- | Simplify the dimension sizes of a 'SegSpace'.  The physical index name
-- and the per-dimension index names are kept unchanged.
instance Engine.Simplifiable SegSpace where
  simplify (SegSpace phys dims) =
    SegSpace phys <$> traverse simplifyDim dims
    where
      -- Only the size (second component) is simplified.
      simplifyDim (gtid, dim) = (,) gtid <$> Engine.simplify dim
-- | Simplify every component of a 'KernelResult' (certificates included)
-- except the 'ResultManifest' of 'Returns', which is kept as-is.  The
-- components are simplified left to right.
instance Engine.Simplifiable KernelResult where
  simplify kres =
    case kres of
      Returns manifest cs what ->
        Returns manifest
          <$> Engine.simplify cs
          <*> Engine.simplify what
      WriteReturns cs ws arr writes ->
        WriteReturns
          <$> Engine.simplify cs
          <*> Engine.simplify ws
          <*> Engine.simplify arr
          <*> Engine.simplify writes
      ConcatReturns cs ordering w pte what ->
        ConcatReturns
          <$> Engine.simplify cs
          <*> Engine.simplify ordering
          <*> Engine.simplify w
          <*> Engine.simplify pte
          <*> Engine.simplify what
      TileReturns cs dims what ->
        TileReturns
          <$> Engine.simplify cs
          <*> Engine.simplify dims
          <*> Engine.simplify what
      RegTileReturns cs dims_n_tiles what ->
        RegTileReturns
          <$> Engine.simplify cs
          <*> Engine.simplify dims_n_tiles
          <*> Engine.simplify what
-- | Construct a 'KernelBody' in the 'Wise' representation.
--
-- The body decoration is borrowed from the 'Body' that 'mkWiseBody'
-- builds for the same statements, using the plain subexpressions of the
-- kernel results ('kernelResultSubExp') as that body's result.  The
-- kernel results themselves are stored unchanged.
mkWiseKernelBody ::
  (ASTRep rep, CanBeWise (Op rep)) =>
  BodyDec rep ->
  Stms (Wise rep) ->
  [KernelResult] ->
  KernelBody (Wise rep)
mkWiseKernelBody dec stms res = KernelBody wise_dec stms res
  where
    Body wise_dec _ _ =
      mkWiseBody dec stms (subExpsRes (map kernelResultSubExp res))
-- | Construct a 'KernelBody' inside a 'MonadBuilder'.
--
-- 'mkBodyM' is run on the statements with the kernel results' plain
-- subexpressions ('kernelResultSubExp') as result.  Only the decoration of
-- that 'Body' is kept; the statements and kernel results are stored as
-- given.
mkKernelBodyM ::
  MonadBuilder m =>
  Stms (Rep m) ->
  [KernelResult] ->
  m (KernelBody (Rep m))
mkKernelBodyM stms kres = do
  body <- mkBodyM stms $ subExpsRes $ map kernelResultSubExp kres
  case body of
    Body dec' _ _ -> pure $ KernelBody dec' stms kres
-- | Simplify the statements and results of a 'KernelBody' that runs in the
-- given 'SegSpace'.
--
-- Returns the simplified body together with the statements hoisted out of
-- it.  'Engine.blockIf' keeps inside the body every statement that matches
-- any of the following:
--
--   * it has free variables bound by the segment space;
--   * it is an 'Op';
--   * the engine's parallel hoist blocker ('Engine.blockHoistPar');
--   * 'Engine.isConsumed' or 'Engine.isConsuming';
--   * 'Engine.isDeviceMigrated'.
--
-- While simplifying, the symbol table is modified in three ways.  Arrays
-- written by 'WriteReturns' results are marked consumed.  The segment
-- space's symbol table is appended.  'ST.simplifyMemory' is switched on.
-- 'Engine.enterLoop' is also applied.
simplifyKernelBody ::
  (Engine.SimplifiableRep rep, BodyDec rep ~ ()) =>
  SegSpace ->
  KernelBody (Wise rep) ->
  Engine.SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
simplifyKernelBody space (KernelBody _ stms res) = do
  par_blocker <-
    Engine.asksEngineEnv (Engine.blockHoistPar . Engine.envHoistBlockers)

  let blocker =
        Engine.hasFree bound_here
          `Engine.orIf` Engine.isOp
          `Engine.orIf` par_blocker
          `Engine.orIf` Engine.isConsumed
          `Engine.orIf` Engine.isConsuming
          `Engine.orIf` Engine.isDeviceMigrated

      -- The names bound by the body's statements are passed to
      -- 'ST.hideCertified' while the kernel results are simplified.
      simplifyResults = do
        res' <-
          Engine.localVtable (ST.hideCertified stm_bound) $
            mapM Engine.simplify res
        pure (res', UT.usages (freeIn res'))

  (body_res, body_stms, hoisted) <-
    withKernelVtable . Engine.enterLoop $
      Engine.blockIf blocker stms simplifyResults

  pure (mkWiseKernelBody () body_stms body_res, hoisted)
  where
    withKernelVtable =
      Engine.localVtable markConsumed
        . Engine.localVtable (<> segSpaceSymbolTable space)
        . Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True})

    -- Mark every array that a result writes into as consumed.
    markConsumed vtable = foldl' (flip ST.consume) vtable consumed_arrs

    consumed_arrs = [arr | WriteReturns _ _ arr _ <- res]

    -- Names bound by the segment space (physical index and gtids).
    bound_here = namesFromList $ M.keys $ scopeOfSegSpace space

    -- Names bound by the statements of the kernel body.
    stm_bound = namesFromList $ M.keys $ scopeOf stms
-- | Simplify a lambda with 'Engine.simplifyLambda', run under
-- 'Engine.blockMigrated'.  Returns the simplified lambda and the
-- statements hoisted out of it.
simplifyLambda ::
  Engine.SimplifiableRep rep =>
  Lambda (Wise rep) ->
  Engine.SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambda lam = Engine.blockMigrated (Engine.simplifyLambda lam)
-- | Build a symbol table for a 'SegSpace'.
--
-- The flat (physical) index is entered as an 'Int64' 'IndexName'.  Each
-- gtid is then added with 'ST.insertLoopVar' as an 'Int64' loop variable
-- together with its dimension size, in the order the dimensions are
-- listed.
segSpaceSymbolTable :: ASTRep rep => SegSpace -> ST.SymbolTable rep
segSpaceSymbolTable (SegSpace flat gtids_and_dims) =
  foldl' addGtid flat_vtable gtids_and_dims
  where
    flat_vtable = ST.fromScope $ M.singleton flat $ IndexName Int64
    addGtid vtable (gtid, dim) = ST.insertLoopVar gtid Int64 dim vtable
-- | Simplify a 'SegBinOp'.
--
-- The operator lambda is simplified first, with 'ST.simplifyMemory'
-- enabled.  The shape is simplified next, then the neutral elements.  The
-- commutativity is kept unchanged.  Statements hoisted out of the lambda
-- are returned alongside the result.
simplifySegBinOp ::
  Engine.SimplifiableRep rep =>
  SegBinOp (Wise rep) ->
  Engine.SimpleM rep (SegBinOp (Wise rep), Stms (Wise rep))
simplifySegBinOp (SegBinOp comm lam nes shape) = do
  (lam', hoisted) <- Engine.localVtable withMemory (simplifyLambda lam)
  shape' <- Engine.simplify shape
  nes' <- traverse Engine.simplify nes
  pure (SegBinOp comm lam' nes' shape', hoisted)
  where
    withMemory vtable = vtable {ST.simplifyMemory = True}
simplifySegOp ::
( Engine.SimplifiableRep rep,
BodyDec rep ~ (),
Engine.Simplifiable lvl
) =>
SegOp lvl (Wise rep) ->
Engine.SimpleM rep (SegOp lvl (Wise rep), Stms (Wise rep))
simplifySegOp :: SegOp lvl (Wise rep)
-> SimpleM rep (SegOp lvl (Wise rep), Stms (Wise rep))
simplifySegOp (SegMap lvl
lvl SegSpace
space [Type]
ts KernelBody (Wise rep)
kbody) = do
(lvl
lvl', SegSpace
space', [Type]
ts') <- (lvl, SegSpace, [Type]) -> SimpleM rep (lvl, SegSpace, [Type])
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
Engine.simplify (lvl
lvl, SegSpace
space, [Type]
ts)
(KernelBody (Wise rep)
kbody', Stms (Wise rep)
body_hoisted) <- SegSpace
-> KernelBody (Wise rep)
-> SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
forall rep.
(SimplifiableRep rep, BodyDec rep ~ ()) =>
SegSpace
-> KernelBody (Wise rep)
-> SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
simplifyKernelBody SegSpace
space KernelBody (Wise rep)
kbody
(SegOp lvl (Wise rep), Stms (Wise rep))
-> SimpleM rep (SegOp lvl (Wise rep), Stms (Wise rep))
forall (f :: * -> *) a. Applicative f => a -> f a
pure
( lvl
-> SegSpace
-> [Type]
-> KernelBody (Wise rep)
-> SegOp lvl (Wise rep)
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap lvl
lvl' SegSpace
space' [Type]
ts' KernelBody (Wise rep)
kbody',
Stms (Wise rep)
body_hoisted
)
-- Simplify a segmented reduction.  The level, space and result types
-- are simplified together; each reduction operator is simplified with
-- the bindings of 'scopeOfSegSpace space' added to the symbol table;
-- the kernel body is handed to 'simplifyKernelBody'.  Statements
-- hoisted out of the operators are placed before those hoisted out of
-- the kernel body.
simplifySegOp (SegRed lvl space reds ts kbody) = do
  (lvl_s, space_s, ts_s) <- Engine.simplify (lvl, space, ts)
  (reds_s, reds_hoisted) <-
    Engine.localVtable (<> ST.fromScope (scopeOfSegSpace space)) $
      fmap unzip (traverse simplifySegBinOp reds)
  (kbody_s, body_hoisted) <- simplifyKernelBody space kbody
  let hoisted = mconcat reds_hoisted <> body_hoisted
  pure (SegRed lvl_s space_s reds_s ts_s kbody_s, hoisted)
-- Simplify a segmented scan.  The level, space and result types are
-- simplified together; each scan operator is simplified with the
-- bindings of 'scopeOfSegSpace space' added to the symbol table; the
-- kernel body is handed to 'simplifyKernelBody'.  Statements hoisted
-- out of the operators are placed before those hoisted out of the
-- kernel body.
simplifySegOp (SegScan lvl space scans ts kbody) = do
  (lvl_s, space_s, ts_s) <- Engine.simplify (lvl, space, ts)
  (scans_s, scans_hoisted) <-
    Engine.localVtable (<> ST.fromScope (scopeOfSegSpace space)) $
      fmap unzip (traverse simplifySegBinOp scans)
  (kbody_s, body_hoisted) <- simplifyKernelBody space kbody
  let hoisted = mconcat scans_hoisted <> body_hoisted
  pure (SegScan lvl_s space_s scans_s ts_s kbody_s, hoisted)
-- Simplify a segmented histogram.  Every component of each 'HistOp'
-- is simplified in field order.  Its operator lambda is simplified with
-- the bindings of 'scopeOfSegSpace space' added to the symbol table and
-- with 'ST.simplifyMemory' set to 'True'.  Statements hoisted out of
-- the operator lambdas are placed before those hoisted out of the
-- kernel body.
simplifySegOp (SegHist lvl space ops ts kbody) = do
  (lvl_s, space_s, ts_s) <- Engine.simplify (lvl, space, ts)
  (ops_s, ops_hoisted) <- unzip <$> mapM simplifyHistOp ops
  (kbody_s, body_hoisted) <- simplifyKernelBody space kbody
  pure
    ( SegHist lvl_s space_s ops_s ts_s kbody_s,
      mconcat ops_hoisted <> body_hoisted
    )
  where
    scope_vtable = ST.fromScope (scopeOfSegSpace space)

    -- Simplify a single operator, returning it together with the
    -- statements hoisted out of its lambda.  The applicative chain runs
    -- the 'Engine.simplify' calls left to right, before the lambda.
    simplifyHistOp (HistOp w rf arrs nes dims lam) = do
      withLam <-
        HistOp
          <$> Engine.simplify w
          <*> Engine.simplify rf
          <*> Engine.simplify arrs
          <*> Engine.simplify nes
          <*> Engine.simplify dims
      (lam_s, lam_hoisted) <-
        Engine.localVtable (<> scope_vtable)
          . Engine.localVtable (\vt -> vt {ST.simplifyMemory = True})
          $ simplifyLambda lam
      pure (withLam lam_s, lam_hoisted)
-- | A representation whose 'Op' type can contain 'SegOp's.  The two
-- methods convert between the representation's 'Op' and a 'SegOp' at
-- the level given by 'SegOpLevel'.
class HasSegOp rep where
  -- | The level type of the 'SegOp's embedded in this representation.
  type SegOpLevel rep
  -- | View an 'Op' as a 'SegOp'.  A 'Nothing' result is treated by the
  -- simplification rules in this module as "not a 'SegOp'" (they 'Skip').
  asSegOp :: Op rep -> Maybe (SegOp (SegOpLevel rep) rep)
  -- | Embed a 'SegOp' into the representation's 'Op'.
  segOp :: SegOp (SegOpLevel rep) rep -> Op rep
-- | The simplification rule book for 'SegOp's: one top-down rule
-- ('segOpRuleTopDown') and one bottom-up rule ('segOpRuleBottomUp').
segOpRules ::
  (HasSegOp rep, BuilderOps rep, Buildable rep, Aliased rep) =>
  RuleBook rep
segOpRules = ruleBook topDownRules bottomUpRules
  where
    topDownRules = [RuleOp segOpRuleTopDown]
    bottomUpRules = [RuleOp segOpRuleBottomUp]
-- | Top-down rule: if 'asSegOp' recognises the 'Op', delegate to
-- 'topDownSegOp'; otherwise 'Skip'.
segOpRuleTopDown ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  TopDownRuleOp rep
segOpRuleTopDown vtable pat dec op =
  maybe Skip (topDownSegOp vtable pat dec) (asSegOp op)
-- | Bottom-up rule: if 'asSegOp' recognises the 'Op', delegate to
-- 'bottomUpSegOp'; otherwise 'Skip'.
segOpRuleBottomUp ::
  (HasSegOp rep, BuilderOps rep, Aliased rep) =>
  BottomUpRuleOp rep
segOpRuleBottomUp vtable pat dec op =
  case asSegOp op of
    Nothing -> Skip
    Just seg_op -> bottomUpSegOp vtable pat dec seg_op
-- | Top-down simplifications of a 'SegOp' bound by a statement with
-- the given pattern and auxiliary information.
topDownSegOp ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  ST.SymbolTable rep ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
-- A 'SegMap' result is treated as invariant when it meets three
-- conditions: it is a 'Returns' marked 'ResultMaySimplify', it carries
-- no certificates, and its subexpression is a constant or a variable
-- found in the symbol table given to this rule.  Each such result is
-- removed from the 'SegMap', and its pattern element is bound instead
-- to a 'Replicate' of the value over the segment-space dimensions.  If
-- no result was removed, the rule fails with 'cannotSimplify'.
topDownSegOp vtable (Pat kpes) dec (SegMap lvl space ts (KernelBody _ kstms kres)) =
  Simplify $ do
    kept <- catMaybes <$> mapM hoistIfInvariant (zip3 ts kpes kres)
    let (ts', kpes', kres') = unzip3 kept

    -- Nothing was replaced by a replicate, so the rule does not apply.
    when (kres == kres') cannotSimplify

    kbody <- mkKernelBodyM kstms kres'
    addStm . Let (Pat kpes') dec . Op . segOp $ SegMap lvl space ts' kbody
  where
    invariant (Constant _) = True
    invariant (Var v) = isJust (ST.lookup v vtable)

    -- For an invariant result, bind its pattern element to a replicate
    -- and drop it ('Nothing').  Keep every other result unchanged.
    -- Results are visited left to right, so the replicates are bound in
    -- result order.
    hoistIfInvariant res@(_, pe, kr) =
      case kr of
        Returns ResultMaySimplify cs se
          | cs == mempty,
            invariant se -> do
              letBindNames [patElemName pe] . BasicOp $
                Replicate (Shape (segSpaceDims space)) se
              pure Nothing
        _ -> pure (Just res)
topDownSegOp SymbolTable rep
_ (Pat [PatElem (LetDec rep)]
pes) StmAux (ExpDec rep)
_ (SegRed SegOpLevel rep
lvl SegSpace
space [SegBinOp rep]
ops [Type]
ts KernelBody rep
kbody)
| [SegBinOp rep] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegBinOp rep]
ops Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
[[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings <-
((SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
-> (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
-> Bool)
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
forall a. (a -> a -> Bool) -> [a] -> [[a]]
groupBy (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
-> (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
-> Bool
forall rep b rep b. (SegBinOp rep, b) -> (SegBinOp rep, b) -> Bool
sameShape ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> [[(SegBinOp rep,
[(PatElem (LetDec rep), Type, KernelResult)])]])
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
forall a b. (a -> b) -> a -> b
$
[SegBinOp rep]
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp rep]
ops ([[(PatElem (LetDec rep), Type, KernelResult)]]
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])])
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
forall a b. (a -> b) -> a -> b
$
[Int]
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp rep -> Int) -> [SegBinOp rep] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp rep -> [SubExp]) -> SegBinOp rep -> Int
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral) [SegBinOp rep]
ops) ([(PatElem (LetDec rep), Type, KernelResult)]
-> [[(PatElem (LetDec rep), Type, KernelResult)]])
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
forall a b. (a -> b) -> a -> b
$
[PatElem (LetDec rep)]
-> [Type]
-> [KernelResult]
-> [(PatElem (LetDec rep), Type, KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (LetDec rep)]
red_pes [Type]
red_ts [KernelResult]
red_res,
([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Bool)
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
-> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any ((Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1) (Int -> Bool)
-> ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Int)
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Bool
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length) [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
let ([SegBinOp rep]
ops', [[(PatElem (LetDec rep), Type, KernelResult)]]
aux) = [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> ([SegBinOp rep], [[(PatElem (LetDec rep), Type, KernelResult)]])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> ([SegBinOp rep],
[[(PatElem (LetDec rep), Type, KernelResult)]]))
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> ([SegBinOp rep], [[(PatElem (LetDec rep), Type, KernelResult)]])
forall a b. (a -> b) -> a -> b
$ ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Maybe
(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)]))
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Maybe
(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
forall rep a.
Buildable rep =>
[(SegBinOp rep, [a])] -> Maybe (SegBinOp rep, [a])
combineOps [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings
([PatElem (LetDec rep)]
red_pes', [Type]
red_ts', [KernelResult]
red_res') = [(PatElem (LetDec rep), Type, KernelResult)]
-> ([PatElem (LetDec rep)], [Type], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(PatElem (LetDec rep), Type, KernelResult)]
-> ([PatElem (LetDec rep)], [Type], [KernelResult]))
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> ([PatElem (LetDec rep)], [Type], [KernelResult])
forall a b. (a -> b) -> a -> b
$ [[(PatElem (LetDec rep), Type, KernelResult)]]
-> [(PatElem (LetDec rep), Type, KernelResult)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[(PatElem (LetDec rep), Type, KernelResult)]]
aux
pes' :: [PatElem (LetDec rep)]
pes' = [PatElem (LetDec rep)]
red_pes' [PatElem (LetDec rep)]
-> [PatElem (LetDec rep)] -> [PatElem (LetDec rep)]
forall a. [a] -> [a] -> [a]
++ [PatElem (LetDec rep)]
map_pes
ts' :: [Type]
ts' = [Type]
red_ts' [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
map_ts
kbody' :: KernelBody rep
kbody' = KernelBody rep
kbody {kernelBodyResult :: [KernelResult]
kernelBodyResult = [KernelResult]
red_res' [KernelResult] -> [KernelResult] -> [KernelResult]
forall a. [a] -> [a] -> [a]
++ [KernelResult]
map_res}
Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
pes') (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Op rep -> Exp rep
forall rep. Op rep -> Exp rep
Op (Op rep -> Exp rep) -> Op rep -> Exp rep
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel rep) rep -> Op rep
forall rep. HasSegOp rep => SegOp (SegOpLevel rep) rep -> Op rep
segOp (SegOp (SegOpLevel rep) rep -> Op rep)
-> SegOp (SegOpLevel rep) rep -> Op rep
forall a b. (a -> b) -> a -> b
$ SegOpLevel rep
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp (SegOpLevel rep) rep
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp lvl rep
SegRed SegOpLevel rep
lvl SegSpace
space [SegBinOp rep]
ops' [Type]
ts' KernelBody rep
kbody'
where
([PatElem (LetDec rep)]
red_pes, [PatElem (LetDec rep)]
map_pes) = Int
-> [PatElem (LetDec rep)]
-> ([PatElem (LetDec rep)], [PatElem (LetDec rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp rep] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) [PatElem (LetDec rep)]
pes
([Type]
red_ts, [Type]
map_ts) = Int -> [Type] -> ([Type], [Type])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp rep] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) [Type]
ts
([KernelResult]
red_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp rep] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody rep -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody rep
kbody
sameShape :: (SegBinOp rep, b) -> (SegBinOp rep, b) -> Bool
sameShape (SegBinOp rep
op1, b
_) (SegBinOp rep
op2, b
_) = SegBinOp rep -> Shape
forall rep. SegBinOp rep -> Shape
segBinOpShape SegBinOp rep
op1 Shape -> Shape -> Bool
forall a. Eq a => a -> a -> Bool
== SegBinOp rep -> Shape
forall rep. SegBinOp rep -> Shape
segBinOpShape SegBinOp rep
op2
combineOps :: [(SegBinOp rep, [a])] -> Maybe (SegBinOp rep, [a])
combineOps [] = Maybe (SegBinOp rep, [a])
forall a. Maybe a
Nothing
combineOps ((SegBinOp rep, [a])
x : [(SegBinOp rep, [a])]
xs) = (SegBinOp rep, [a]) -> Maybe (SegBinOp rep, [a])
forall a. a -> Maybe a
Just ((SegBinOp rep, [a]) -> Maybe (SegBinOp rep, [a]))
-> (SegBinOp rep, [a]) -> Maybe (SegBinOp rep, [a])
forall a b. (a -> b) -> a -> b
$ ((SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a]))
-> (SegBinOp rep, [a])
-> [(SegBinOp rep, [a])]
-> (SegBinOp rep, [a])
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' (SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a])
forall rep a.
Buildable rep =>
(SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a])
combine (SegBinOp rep, [a])
x [(SegBinOp rep, [a])]
xs
combine :: (SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a])
combine (SegBinOp rep
op1, [a]
op1_aux) (SegBinOp rep
op2, [a]
op2_aux) =
let lam1 :: Lambda rep
lam1 = SegBinOp rep -> Lambda rep
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp rep
op1
lam2 :: Lambda rep
lam2 = SegBinOp rep -> Lambda rep
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp rep
op2
([Param (LParamInfo rep)]
op1_xparams, [Param (LParamInfo rep)]
op1_yparams) =
Int
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op1)) ([Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)]))
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [Param (LParamInfo rep)]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam1
([Param (LParamInfo rep)]
op2_xparams, [Param (LParamInfo rep)]
op2_yparams) =
Int
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op2)) ([Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)]))
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [Param (LParamInfo rep)]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam2
lam :: Lambda rep
lam =
Lambda :: forall rep. [LParam rep] -> Body rep -> [Type] -> Lambda rep
Lambda
{ lambdaParams :: [Param (LParamInfo rep)]
lambdaParams =
[Param (LParamInfo rep)]
op1_xparams
[Param (LParamInfo rep)]
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
op2_xparams
[Param (LParamInfo rep)]
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
op1_yparams
[Param (LParamInfo rep)]
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
op2_yparams,
lambdaReturnType :: [Type]
lambdaReturnType = Lambda rep -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
lam1 [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ Lambda rep -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
lam2,
lambdaBody :: Body rep
lambdaBody =
Stms rep -> Result -> Body rep
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (Body rep -> Stms rep
forall rep. Body rep -> Stms rep
bodyStms (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam1) Stms rep -> Stms rep -> Stms rep
forall a. Semigroup a => a -> a -> a
<> Body rep -> Stms rep
forall rep. Body rep -> Stms rep
bodyStms (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam2)) (Result -> Body rep) -> Result -> Body rep
forall a b. (a -> b) -> a -> b
$
Body rep -> Result
forall rep. Body rep -> Result
bodyResult (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam1) Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> Body rep -> Result
forall rep. Body rep -> Result
bodyResult (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam2)
}
in ( SegBinOp :: forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp
{ segBinOpComm :: Commutativity
segBinOpComm = SegBinOp rep -> Commutativity
forall rep. SegBinOp rep -> Commutativity
segBinOpComm SegBinOp rep
op1 Commutativity -> Commutativity -> Commutativity
forall a. Semigroup a => a -> a -> a
<> SegBinOp rep -> Commutativity
forall rep. SegBinOp rep -> Commutativity
segBinOpComm SegBinOp rep
op2,
segBinOpLambda :: Lambda rep
segBinOpLambda = Lambda rep
lam,
segBinOpNeutral :: [SubExp]
segBinOpNeutral = SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op1 [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op2,
segBinOpShape :: Shape
segBinOpShape = SegBinOp rep -> Shape
forall rep. SegBinOp rep -> Shape
segBinOpShape SegBinOp rep
op1
},
[a]
op1_aux [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
op2_aux
)
-- Fallback equation: no top-down rewrite applies to any other statement,
-- so the rule reports 'Skip'.
topDownSegOp _vtable _pat _aux _segop = Skip
-- | Take a 'SegOp' apart into the pieces that the simplification rules
-- need:
--
-- 1. the result types of the kernel,
-- 2. the kernel body,
-- 3. the number of leading results that are not plain map results
--    (results of the scan/reduction operators, or one per histogram
--    destination array; zero for a 'SegMap'), and
-- 4. a function that rebuilds an operator of the same kind, with the same
--    level, space and operators, from new result types and a new body.
segOpGuts ::
  SegOp (SegOpLevel rep) rep ->
  ( [Type],
    KernelBody rep,
    Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
  )
segOpGuts op =
  case op of
    SegMap lvl space rets kbody ->
      (rets, kbody, 0, SegMap lvl space)
    SegScan lvl space binops rets kbody ->
      (rets, kbody, segBinOpResults binops, SegScan lvl space binops)
    SegRed lvl space binops rets kbody ->
      (rets, kbody, segBinOpResults binops, SegRed lvl space binops)
    SegHist lvl space hists rets kbody ->
      let num_dests = sum [length (histDest h) | h <- hists]
       in (rets, kbody, num_dests, SegHist lvl space hists)
-- | Bottom-up simplification rule for 'SegOp's that hoists map results
-- which are merely slices of arrays bound outside the kernel.
--
-- We walk through the statements of the kernel body.  A statement
-- qualifies when all of the following hold:
--
-- * It is an 'Index' whose leading indices are exactly the thread indices
--   of the segment space, in order.
-- * The indexed array, the remaining slice, and the statement's
--   certificates are all bound outside the kernel.
-- * Its single bound variable is returned directly as a map result.
--
-- Such a result is removed from the operator.  It is instead bound
-- outside the kernel by slicing the whole array.  Results belonging to
-- the scan/reduction/histogram part (the first @num_nonmap_results@
-- results) are never removed.
--
-- The rule fails with 'cannotSimplify' if no result could be hoisted.
bottomUpSegOp ::
  (Aliased rep, HasSegOp rep, BuilderOps rep) =>
  (ST.SymbolTable rep, UT.UsageTable) ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
bottomUpSegOp (vtable, used) (Pat kpes) dec segop = Simplify $ do
  (kpes', kts', kres', kstms') <-
    localScope (scopeOfSegSpace space) $
      foldM distribute (kpes, kts, kres, mempty) kstms

  -- If no result was hoisted, the rule does not apply.
  when (kpes' == kpes) cannotSimplify

  kbody' <-
    localScope (scopeOfSegSpace space) $
      mkKernelBodyM kstms' kres'

  addStm $ Let (Pat kpes') dec $ Op $ segOp $ mk_segop kts' kbody'
  where
    (kts, kbody@(KernelBody _ kstms kres), num_nonmap_results, mk_segop) =
      segOpGuts segop
    free_in_kstms = foldMap freeIn kstms
    consumed_in_segop = consumedInKernelBody kbody
    space = segSpace segop

    -- Recognise @arr[gtid_1, ..., gtid_n, remaining...]@, where the gtids
    -- are the thread indices of the segment space.  On a match, return
    -- the remaining slice and the array.  The match fails unless
    -- everything the hoisted index depends on is already bound outside
    -- the kernel.  That includes the certificates of the statement,
    -- since they will be attached to the hoisted index.
    sliceWithGtidsFixed stm
      | Let _ aux (BasicOp (Index arr slice)) <- stm,
        space_slice <- map (DimFix . Var . fst) $ unSegSpace space,
        space_slice `isPrefixOf` unSlice slice,
        remaining_slice <- Slice $ drop (length space_slice) (unSlice slice),
        all (isJust . flip ST.lookup vtable) $
          namesToList $
            freeIn arr <> freeIn remaining_slice <> freeIn (stmAuxCerts aux) =
          Just (remaining_slice, arr)
      | otherwise =
          Nothing

    -- Fold step.  Either hoist the statement's result out of the kernel
    -- (emitting the sliced index outside), or keep the statement inside.
    distribute (kpes', kts', kres', kstms') stm
      | Let (Pat [pe]) aux _ <- stm,
        Just (Slice remaining_slice, arr) <- sliceWithGtidsFixed stm,
        Just (kpe, kpes'', kts'', kres'') <- isResult kpes' kts' kres' pe = do
          let outer_slice =
                map
                  ( \d ->
                      DimSlice
                        (constant (0 :: Int64))
                        d
                        (constant (1 :: Int64))
                  )
                  $ segSpaceDims space
              -- Keep the original statement's certificates (e.g. bounds
              -- checks) on the hoisted index; dropping them would let the
              -- index run without the checks it depended on.
              index kpe' =
                certifying (stmAuxCerts aux) $
                  letBindNames [patElemName kpe'] $
                    BasicOp $
                      Index arr $
                        Slice $
                          outer_slice <> remaining_slice
          -- An 'Index' result aliases @arr@.  If the outer result is
          -- consumed later, or @arr@ is consumed inside the kernel, bind
          -- a copy instead.
          if patElemName kpe `UT.isConsumed` used || arr `nameIn` consumed_in_segop
            then do
              precopy <- newVName $ baseString (patElemName kpe) <> "_precopy"
              index kpe {patElemName = precopy}
              letBindNames [patElemName kpe] $ BasicOp $ Copy precopy
            else index kpe
          pure
            ( kpes'',
              kts'',
              kres'',
              -- The statement must stay in the kernel if other kernel
              -- statements still refer to its result.
              if patElemName pe `nameIn` free_in_kstms
                then kstms' <> oneStm stm
                else kstms'
            )
    distribute (kpes', kts', kres', kstms') stm =
      pure (kpes', kts', kres', kstms' <> oneStm stm)

    -- Find the unique kernel result that returns exactly the variable
    -- bound by @pe@.  It must also be a map result, i.e. positioned at
    -- or after @num_nonmap_results@ in the original pattern.  On success,
    -- return the matching outer pattern element together with the
    -- remaining pattern elements, types and results.
    isResult kpes' kts' kres' pe =
      case partition matches $ zip3 kpes' kts' kres' of
        ([(kpe, _, _)], kpes_and_kres)
          | Just i <- elemIndex kpe kpes,
            i >= num_nonmap_results,
            (kpes'', kts'', kres'') <- unzip3 kpes_and_kres ->
              Just (kpe, kpes'', kts'', kres'')
        _ -> Nothing
      where
        matches (_, _, Returns _ _ (Var v)) = v == patElemName pe
        matches _ = False
-- | Refine the return information of a kernel, one entry per kernel
-- result.
--
-- * A 'WriteReturns' result is described by the return information of
--   the array it writes into, as computed by 'varReturns'.
-- * Every other result keeps the 'ExpReturns' it was given.
kernelBodyReturns ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  KernelBody somerep ->
  [ExpReturns] ->
  m [ExpReturns]
kernelBodyReturns kbody rets =
  zipWithM adjust (kernelBodyResult kbody) rets
  where
    adjust res ret =
      case res of
        WriteReturns _ _ arr _ -> varReturns arr
        _ -> pure ret
-- | Compute the memory return information of a 'SegOp'.
--
-- * For 'SegMap', 'SegRed' and 'SegScan', the result comes from
--   'extReturns' applied to the operator's type, adjusted per result by
--   'kernelBodyReturns'.
-- * For 'SegHist', each result is described by the return information of
--   its destination array, in the order given by 'histDest'.
segOpReturns ::
  (Mem rep inner, Monad m, HasScope rep m) =>
  SegOp lvl somerep ->
  m [ExpReturns]
segOpReturns k =
  case k of
    SegMap _ _ _ kbody -> viaKernelBody kbody
    SegRed _ _ _ _ kbody -> viaKernelBody kbody
    SegScan _ _ _ _ kbody -> viaKernelBody kbody
    SegHist _ _ ops _ _ ->
      concat <$> mapM (\op -> mapM varReturns (histDest op)) ops
  where
    viaKernelBody kbody = do
      ext_ts <- opType k
      kernelBodyReturns kbody (extReturns ext_ts)