{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Segmented operations.  These correspond to perfect @map@ nests on
-- top of /something/, except that the @map@s are conceptually only
-- over @iota@s (so there will be explicit indexing inside them).
module Futhark.IR.SegOp
  ( SegOp (..),
    segLevel,
    segBody,
    segSpace,
    typeCheckSegOp,
    SegSpace (..),
    scopeOfSegSpace,
    segSpaceDims,

    -- * Details
    HistOp (..),
    histType,
    splitHistResults,
    SegBinOp (..),
    segBinOpResults,
    segBinOpChunks,
    KernelBody (..),
    aliasAnalyseKernelBody,
    consumedInKernelBody,
    ResultManifest (..),
    KernelResult (..),
    kernelResultCerts,
    kernelResultSubExp,

    -- ** Generic traversal
    SegOpMapper (..),
    identitySegOpMapper,
    mapSegOpM,
    traverseSegOpStms,

    -- * Simplification
    simplifySegOp,
    HasSegOp (..),
    segOpRules,

    -- * Memory
    segOpReturns,
  )
where

import Control.Category
import Control.Monad
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Bifunctor (first)
import Data.Bitraversable
import Data.Foldable (traverse_)
import Data.List
  ( elemIndex,
    foldl',
    groupBy,
    intersperse,
    isPrefixOf,
    partition,
  )
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR
import Futhark.IR.Aliases
  ( Aliases,
    CanBeAliased (..),
  )
import Futhark.IR.Mem
import Futhark.IR.Prop.Aliases
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty
  ( Doc,
    apply,
    hsep,
    parens,
    ppTuple',
    pretty,
    (<+>),
    (</>),
  )
import Futhark.Util.Pretty qualified as PP
import Prelude hiding (id, (.))

-- | An operator for 'SegHist'.
data HistOp rep = HistOp
  { -- | Shape of the histogram.  Its rank is the number of index
    -- results that belong to this operator (see 'splitHistResults').
    histShape :: Shape,
    histRaceFactor :: SubExp,
    -- | Destination arrays.  There is one value result per
    -- destination (see 'splitHistResults').
    histDest :: [VName],
    histNeutral :: [SubExp],
    -- | In case this operator is semantically a vectorised
    -- operator (corresponding to a perfect map nest in the
    -- SOACS representation), these are the logical
    -- "dimensions".  This is used to generate more efficient
    -- code.
    histOpShape :: Shape,
    -- | The operator lambda.  Its return types, wrapped in
    -- 'histShape' and 'histOpShape', give the histogram types (see
    -- 'histType').
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)

-- | The type of a histogram produced by a 'HistOp'.  This can be
-- different from the type of the 'histDest's in case we are
-- dealing with a segmented histogram.
--
-- Each return type of the operator lambda is turned into an array
-- type whose shape is 'histShape' followed by 'histOpShape'.
histType :: HistOp rep -> [Type]
histType op =
  map (`arrayOfShape` (histShape op <> histOpShape op)) $
    lambdaReturnType $
      histOp op

-- | Split reduction results returned by a 'KernelBody' into those
-- that correspond to indexes for the 'HistOp's, and those that
-- correspond to values.
--
-- The input is expected to hold all indices first, followed by all
-- values.  Each operator takes as many indices as the rank of its
-- 'histShape' and as many values as it has 'histDest's, in operator
-- order.  The result pairs each operator's indices with its values.
splitHistResults :: [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ops res =
  let ranks = map (shapeRank . histShape) ops
      (idxs, vals) = splitAt (sum ranks) res
   in zip
        (chunks ranks idxs)
        (chunks (map (length . histDest) ops) vals)

-- | An operator for 'SegScan' and 'SegRed'.
data SegBinOp rep = SegBinOp
  { segBinOpComm :: Commutativity,
    segBinOpLambda :: Lambda rep,
    -- | Neutral elements.  Their number is also the number of
    -- results produced by this operator (see 'segBinOpResults').
    segBinOpNeutral :: [SubExp],
    -- | In case this operator is semantically a vectorised
    -- operator (corresponding to a perfect map nest in the
    -- SOACS representation), these are the logical
    -- "dimensions".  This is used to generate more efficient
    -- code.
    segBinOpShape :: Shape
  }
  deriving (Eq, Ord, Show)

-- | How many reduction results are produced by these 'SegBinOp's?
-- This is the total number of neutral elements over all operators.
segBinOpResults :: [SegBinOp rep] -> Int
segBinOpResults = sum . map (length . segBinOpNeutral)

-- | Split some list into chunks equal to the number of values
-- returned by each 'SegBinOp' (i.e. the number of its neutral
-- elements).
segBinOpChunks :: [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks = chunks . map (length . segBinOpNeutral)

-- | The body of a 'SegOp'.  Unlike an ordinary 'Body', it returns a
-- list of 'KernelResult's instead of a 'Result'.
data KernelBody rep = KernelBody
  { kernelBodyDec :: BodyDec rep,
    kernelBodyStms :: Stms rep,
    kernelBodyResult :: [KernelResult]
  }

deriving instance (RepTypes rep) => Ord (KernelBody rep)

deriving instance (RepTypes rep) => Show (KernelBody rep)

deriving instance (RepTypes rep) => Eq (KernelBody rep)

-- | Metadata about whether there is a subtle point to this
-- 'KernelResult'.  This is used to protect things like tiling, which
-- might otherwise be removed by the simplifier because they're
-- semantically redundant.  This has no semantic effect and can be
-- ignored at code generation.
data ResultManifest
  = -- | Don't simplify this one!
    ResultNoSimplify
  | -- | Go nuts.
    ResultMaySimplify
  | -- | The results produced are only used within the
    -- same physical thread later on, and can thus be
    -- kept in registers.
    ResultPrivate
  deriving (Eq, Show, Ord)

-- | A 'KernelBody' does not return an ordinary 'Result'.  Instead, it
-- returns a list of these.
data KernelResult
  = -- | Each "worker" in the kernel returns this.
    -- Whether this is a result-per-thread or a
    -- result-per-group depends on where the 'SegOp' occurs.
    Returns ResultManifest Certs SubExp
  | WriteReturns
      Certs
      Shape -- Size of array.  Must match number of dims.
      VName -- Which array.
      [(Slice SubExp, SubExp)] -- (Slice to write, value written) pairs.
  | -- The TileReturns must not expect more than one
    -- result to be written per physical thread.
    TileReturns
      Certs
      [(SubExp, SubExp)] -- Total/tile for each dimension.
      VName -- Tile written by this worker.
  | RegTileReturns
      Certs
      -- For each dim of result:
      [ ( SubExp, -- size of this dim.
          SubExp, -- block tile size for this dim.
          SubExp -- reg tile size for this dim.
        )
      ]
      VName -- Tile returned by this worker/group.
  deriving (Eq, Show, Ord)

-- | Get the certs for this 'KernelResult'.  Every constructor carries
-- exactly one 'Certs' value.
kernelResultCerts :: KernelResult -> Certs
kernelResultCerts (Returns _ cs _) = cs
kernelResultCerts (WriteReturns cs _ _ _) = cs
kernelResultCerts (TileReturns cs _ _) = cs
kernelResultCerts (RegTileReturns cs _ _) = cs

-- | Get the t'SubExp' at the root of a 'KernelResult'.  For 'Returns'
-- this is the returned value.  For the other constructors it is the
-- array being written to or returned, wrapped in 'Var'.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp (Returns _ _ se) = se
kernelResultSubExp (WriteReturns _ _ arr _) = Var arr
kernelResultSubExp (TileReturns _ _ v) = Var v
kernelResultSubExp (RegTileReturns _ _ v) = Var v

-- Free variables of a kernel result: its certificates plus every
-- name mentioned in its payload (value, shape, array, slices, tiles).
instance FreeIn KernelResult where
  freeIn' (Returns _ cs what) = freeIn' cs <> freeIn' what
  freeIn' (WriteReturns cs rws arr res) =
    freeIn' cs <> freeIn' rws <> freeIn' arr <> freeIn' res
  freeIn' (TileReturns cs dims v) =
    freeIn' cs <> freeIn' dims <> freeIn' v
  freeIn' (RegTileReturns cs dims_n_tiles v) =
    freeIn' cs <> freeIn' dims_n_tiles <> freeIn' v

-- Free variables of a kernel body.  Names bound by the body's
-- statements are removed from the free variables of the whole body
-- (decoration, statements and results).
instance (ASTRep rep) => FreeIn (KernelBody rep) where
  freeIn' (KernelBody dec stms res) =
    fvBind bound_in_stms $ freeIn' dec <> freeIn' stms <> freeIn' res
    where
      bound_in_stms = foldMap boundByStm stms

-- Substitution is applied independently to the decoration, the
-- statements and the results.
instance (ASTRep rep) => Substitute (KernelBody rep) where
  substituteNames subst (KernelBody dec stms res) =
    KernelBody
      (substituteNames subst dec)
      (substituteNames subst stms)
      (substituteNames subst res)

-- Substitution is applied to every component of a kernel result
-- except the 'ResultManifest', which contains no names.
instance Substitute KernelResult where
  substituteNames subst (Returns manifest cs se) =
    Returns manifest (substituteNames subst cs) (substituteNames subst se)
  substituteNames subst (WriteReturns cs rws arr res) =
    WriteReturns
      (substituteNames subst cs)
      (substituteNames subst rws)
      (substituteNames subst arr)
      (substituteNames subst res)
  substituteNames subst (TileReturns cs dims v) =
    TileReturns
      (substituteNames subst cs)
      (substituteNames subst dims)
      (substituteNames subst v)
  substituteNames subst (RegTileReturns cs dims_n_tiles v) =
    RegTileReturns
      (substituteNames subst cs)
      (substituteNames subst dims_n_tiles)
      (substituteNames subst v)

-- The decoration is renamed first.  The statements are then renamed
-- with 'renamingStms', and the results are renamed inside that scope
-- so that they refer to the renamed statement bindings.
instance (ASTRep rep) => Rename (KernelBody rep) where
  rename (KernelBody dec stms res) = do
    dec' <- rename dec
    renamingStms stms $ \stms' ->
      KernelBody dec' stms' <$> rename res

-- Renaming a kernel result is done via 'substituteRename'
-- (presumably because a 'KernelResult' binds no names of its own).
instance Rename KernelResult where
  rename = substituteRename

-- | Perform alias analysis on a 'KernelBody'.  The statements are
-- analysed by wrapping them in an ordinary 'Body' with an empty
-- result.  The kernel results are carried over unchanged.
aliasAnalyseKernelBody ::
  (Alias.AliasableRep rep) =>
  AliasTable ->
  KernelBody rep ->
  KernelBody (Aliases rep)
aliasAnalyseKernelBody aliases (KernelBody dec stms res) =
  let Body dec' stms' _ = Alias.analyseBody aliases $ Body dec stms []
   in KernelBody dec' stms' res

-- | The variables consumed in the kernel body.  These are the names
-- consumed by its statements (computed via 'consumedInBody' on an
-- equivalent 'Body' with an empty result), plus the destination
-- array of every 'WriteReturns' result.
consumedInKernelBody ::
  (Aliased rep) =>
  KernelBody rep ->
  Names
consumedInKernelBody (KernelBody dec stms res) =
  consumedInBody (Body dec stms []) <> foldMap consumedByReturn res
  where
    -- Only 'WriteReturns' consumes anything: the array it writes into.
    consumedByReturn (WriteReturns _ _ a _) = oneName a
    consumedByReturn _ = mempty

checkKernelBody ::
  TC.Checkable rep =>
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkKernelBody :: forall rep.
Checkable rep =>
[Type] -> KernelBody (Aliases rep) -> TypeM rep ()
checkKernelBody [Type]
ts (KernelBody (BodyAliasing
_, BodyDec rep
dec) Stms (Aliases rep)
stms [KernelResult]
kres) = do
  forall rep. Checkable rep => BodyDec rep -> TypeM rep ()
TC.checkBodyDec BodyDec rep
dec
  -- We consume the kernel results (when applicable) before
  -- type-checking the stms, so we will get an error if a statement
  -- uses an array that is written to in a result.
  forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall {rep}. Checkable rep => KernelResult -> TypeM rep ()
consumeKernelResult [KernelResult]
kres
  forall rep a.
Checkable rep =>
Stms (Aliases rep) -> TypeM rep a -> TypeM rep a
TC.checkStms Stms (Aliases rep)
stms forall a b. (a -> b) -> a -> b
$ do
    forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
ts forall a. Eq a => a -> a -> Bool
== forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
kres) forall a b. (a -> b) -> a -> b
$
      forall rep a. ErrorCase rep -> TypeM rep a
TC.bad forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall rep. Text -> ErrorCase rep
TC.TypeError forall a b. (a -> b) -> a -> b
$
        Text
"Kernel return type is "
          forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => [a] -> Text
prettyTuple [Type]
ts
          forall a. Semigroup a => a -> a -> a
<> Text
", but body returns "
          forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText (forall (t :: * -> *) a. Foldable t => t a -> Int
length [KernelResult]
kres)
          forall a. Semigroup a => a -> a -> a
<> Text
" values."
    forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ forall {rep}. Checkable rep => KernelResult -> Type -> TypeM rep ()
checkKernelResult [KernelResult]
kres [Type]
ts
  where
    consumeKernelResult :: KernelResult -> TypeM rep ()
consumeKernelResult (WriteReturns Certs
_ ShapeBase SubExp
_ VName
arr [(Slice SubExp, SubExp)]
_) =
      forall rep. Checkable rep => Names -> TypeM rep ()
TC.consume forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall rep. Checkable rep => VName -> TypeM rep Names
TC.lookupAliases VName
arr
    consumeKernelResult KernelResult
_ =
      forall (f :: * -> *) a. Applicative f => a -> f a
pure ()

    checkKernelResult :: KernelResult -> Type -> TypeM rep ()
checkKernelResult (Returns ResultManifest
_ Certs
cs SubExp
what) Type
t = do
      forall rep. Checkable rep => Certs -> TypeM rep ()
TC.checkCerts Certs
cs
      forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [Type
t] SubExp
what
    checkKernelResult (WriteReturns Certs
cs ShapeBase SubExp
shape VName
arr [(Slice SubExp, SubExp)]
res) Type
t = do
      forall rep. Checkable rep => Certs -> TypeM rep ()
TC.checkCerts Certs
cs
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) forall a b. (a -> b) -> a -> b
$ forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
shape
      Type
arr_t <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(Slice SubExp, SubExp)]
res forall a b. (a -> b) -> a -> b
$ \(Slice SubExp
slice, SubExp
e) -> do
        forall (t :: * -> *) (f :: * -> *) a b.
(Foldable t, Applicative f) =>
(a -> f b) -> t a -> f ()
traverse_ (forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) Slice SubExp
slice
        forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [Type
t] SubExp
e
        forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Type
arr_t forall a. Eq a => a -> a -> Bool
== Type
t Type -> ShapeBase SubExp -> Type
`arrayOfShape` ShapeBase SubExp
shape) forall a b. (a -> b) -> a -> b
$
          forall rep a. ErrorCase rep -> TypeM rep a
TC.bad forall a b. (a -> b) -> a -> b
$
            forall rep. Text -> ErrorCase rep
TC.TypeError forall a b. (a -> b) -> a -> b
$
              Text
"WriteReturns returning "
                forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText SubExp
e
                forall a. Semigroup a => a -> a -> a
<> Text
" of type "
                forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText Type
t
                forall a. Semigroup a => a -> a -> a
<> Text
", shape="
                forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText ShapeBase SubExp
shape
                forall a. Semigroup a => a -> a -> a
<> Text
", but destination array has type "
                forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText Type
arr_t
    checkKernelResult (TileReturns Certs
cs [(SubExp, SubExp)]
dims VName
v) Type
t = do
      forall rep. Checkable rep => Certs -> TypeM rep ()
TC.checkCerts Certs
cs
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(SubExp, SubExp)]
dims forall a b. (a -> b) -> a -> b
$ \(SubExp
dim, SubExp
tile) -> do
        forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
dim
        forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
tile
      Type
vt <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
v
      forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Type
vt forall a. Eq a => a -> a -> Bool
== Type
t Type -> ShapeBase SubExp -> Type
`arrayOfShape` forall d. [d] -> ShapeBase d
Shape (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> b
snd [(SubExp, SubExp)]
dims)) forall a b. (a -> b) -> a -> b
$
        forall rep a. ErrorCase rep -> TypeM rep a
TC.bad forall a b. (a -> b) -> a -> b
$
          forall rep. Text -> ErrorCase rep
TC.TypeError forall a b. (a -> b) -> a -> b
$
            Text
"Invalid type for TileReturns " forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText VName
v
    checkKernelResult (RegTileReturns Certs
cs [(SubExp, SubExp, SubExp)]
dims_n_tiles VName
arr) Type
t = do
      forall rep. Checkable rep => Certs -> TypeM rep ()
TC.checkCerts Certs
cs
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) [SubExp]
dims
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) [SubExp]
blk_tiles
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) [SubExp]
reg_tiles

      -- assert that arr is of element type t and shape (rev outer_tiles ++ reg_tiles)
      Type
arr_t <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
      forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Type
arr_t forall a. Eq a => a -> a -> Bool
== Type
expected) forall a b. (a -> b) -> a -> b
$
        forall rep a. ErrorCase rep -> TypeM rep a
TC.bad forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall rep. Text -> ErrorCase rep
TC.TypeError forall a b. (a -> b) -> a -> b
$
          Text
"Invalid type for TileReturns. Expected:\n  "
            forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText Type
expected
            forall a. Semigroup a => a -> a -> a
<> Text
",\ngot:\n  "
            forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText Type
arr_t
      where
        ([SubExp]
dims, [SubExp]
blk_tiles, [SubExp]
reg_tiles) = forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(SubExp, SubExp, SubExp)]
dims_n_tiles
        expected :: Type
expected = Type
t Type -> ShapeBase SubExp -> Type
`arrayOfShape` forall d. [d] -> ShapeBase d
Shape ([SubExp]
blk_tiles forall a. Semigroup a => a -> a -> a
<> [SubExp]
reg_tiles)

-- | Accumulate the metrics of every statement of a 'KernelBody', in
-- statement order.  The kernel results themselves contribute nothing.
kernelBodyMetrics :: (OpMetrics (Op rep)) => KernelBody rep -> MetricsM ()
kernelBodyMetrics kbody = forM_ (kernelBodyStms kbody) stmMetrics

-- | A kernel body is printed as its statements (stacked), followed by
-- @return@ and the braced, comma-separated kernel results.
instance (PrettyRep rep) => Pretty (KernelBody rep) where
  pretty (KernelBody _ stms res) =
    stms_doc </> "return" <+> res_doc
    where
      stms_doc = PP.stack $ map pretty $ stmsToList stms
      res_doc = PP.braces $ PP.commasep $ map pretty res

-- | The certificates as a (zero- or one-element) list of documents:
-- empty when there are no certificates, so callers can prepend it
-- unconditionally.
certAnnots :: Certs -> [Doc ann]
certAnnots cs = [pretty cs | cs /= mempty]

-- | Every kernel result is printed as its (possibly absent)
-- certificates followed by a constructor-specific description.
instance Pretty KernelResult where
  pretty kres =
    case kres of
      Returns ResultNoSimplify cs what ->
        withCertAnnots cs $ "returns (manifest)" <+> pretty what
      Returns ResultPrivate cs what ->
        withCertAnnots cs $ "returns (private)" <+> pretty what
      Returns ResultMaySimplify cs what ->
        withCertAnnots cs $ "returns" <+> pretty what
      WriteReturns cs shape arr res ->
        withCertAnnots cs $
          pretty arr
            <+> PP.colon
            <+> pretty shape
            </> "with"
            <+> PP.apply [pretty slice <+> "=" <+> pretty e | (slice, e) <- res]
      TileReturns cs dims v ->
        withCertAnnots cs $
          "tile"
            <> apply [pretty dim <+> "/" <+> pretty tile | (dim, tile) <- dims]
            <+> pretty v
      RegTileReturns cs dims_n_tiles v ->
        withCertAnnots cs $
          "blkreg_tile"
            <> apply
              [ pretty dim <+> "/" <+> parens (pretty blk_tile <+> "*" <+> pretty reg_tile)
                | (dim, blk_tile, reg_tile) <- dims_n_tiles
              ]
            <+> pretty v
    where
      -- Prefix a description with the certificate annotations, if any.
      withCertAnnots cs doc = hsep $ certAnnots cs <> [doc]

-- | Index space of a 'SegOp'.
data SegSpace = SegSpace
  { -- | Flat physical index corresponding to the
    -- dimensions (at code generation used for a
    -- thread ID or similar).
    segFlat :: VName,
    -- | The index variables of the space, each paired with the size
    -- of the dimension it ranges over.  The first pair is the
    -- outermost dimension.
    unSegSpace :: [(VName, SubExp)]
  }
  deriving (Eq, Ord, Show)

-- | The sizes spanned by the indexes of the 'SegSpace', outermost
-- first.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims = map snd . unSegSpace

-- | A 'Scope' containing all the identifiers brought into scope by
-- this 'SegSpace': the flat physical index and every per-dimension
-- index, all bound as 64-bit index names.
scopeOfSegSpace :: SegSpace -> Scope rep
scopeOfSegSpace (SegSpace phys space) =
  M.fromList [(v, IndexName Int64) | v <- phys : map fst space]

-- | Check that every dimension size of the 'SegSpace' is a 64-bit
-- integer.  The flat index is not checked here.
checkSegSpace :: (TC.Checkable rep) => SegSpace -> TC.TypeM rep ()
checkSegSpace space =
  forM_ (segSpaceDims space) $ TC.require [Prim int64]

-- | A 'SegOp' is semantically a perfectly nested stack of maps, on
-- top of some bottommost computation (scalar computation, reduction,
-- scan, or histogram).  The 'SegSpace' encodes the original map
-- structure.
--
-- All 'SegOp's are parameterised by the representation of their body,
-- as well as a *level*.  The *level* is a representation-specific bit
-- of information.  For example, in GPU backends, it is used to
-- indicate whether the 'SegOp' is expected to run at the thread-level
-- or the group-level.
--
-- In every constructor, the @['Type']@ field gives the types of the
-- results of the 'KernelBody'; see 'segOpType' for how they determine
-- the type of the whole operation.
data SegOp lvl rep
  = -- | A perfect map nest over the 'SegSpace'.
    SegMap lvl SegSpace [Type] (KernelBody rep)
  | -- | A (possibly segmented) reduction of the innermost dimension
    -- of the 'SegSpace'; the outer dimensions are segments.
    -- The 'SegSpace' must always have at least two dimensions,
    -- implying that the result of a 'SegRed' is always an array.
    SegRed lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | -- | A (possibly segmented) scan; its operator results span every
    -- dimension of the 'SegSpace'.
    SegScan lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | -- | A histogram computation; the body returns bucket indices
    -- followed by the values to combine into the destination arrays.
    SegHist lvl SegSpace [HistOp rep] [Type] (KernelBody rep)
  deriving (Eq, Ord, Show)

-- | The level of a 'SegOp' (its first field, for every constructor).
segLevel :: SegOp lvl rep -> lvl
segLevel segop =
  case segop of
    SegMap lvl _ _ _ -> lvl
    SegRed lvl _ _ _ _ -> lvl
    SegScan lvl _ _ _ _ -> lvl
    SegHist lvl _ _ _ _ -> lvl

-- | The index space of a 'SegOp' (its second field, for every
-- constructor).
segSpace :: SegOp lvl rep -> SegSpace
segSpace segop =
  case segop of
    SegMap _ space _ _ -> space
    SegRed _ space _ _ _ -> space
    SegScan _ space _ _ _ -> space
    SegHist _ space _ _ _ -> space

-- | The body of a 'SegOp' (its last field, for every constructor).
segBody :: SegOp lvl rep -> KernelBody rep
segBody (SegMap _ _ _ body) = body
segBody (SegRed _ _ _ _ body) = body
segBody (SegScan _ _ _ _ body) = body
segBody (SegHist _ _ _ _ body) = body

-- | The full type produced by one kernel result whose per-thread
-- element type is the given 'Type'.
--
-- * 'Returns': one outer dimension per dimension of the space.
-- * 'WriteReturns': the shape of the destination array.
-- * 'TileReturns' / 'RegTileReturns': the listed (untiled) dimensions.
segResultShape :: SegSpace -> Type -> KernelResult -> Type
segResultShape space t kres =
  case kres of
    Returns {} ->
      foldr (flip arrayOfRow) t $ segSpaceDims space
    WriteReturns _ shape _ _ ->
      t `arrayOfShape` shape
    TileReturns _ dims _ ->
      t `arrayOfShape` Shape (map fst dims)
    RegTileReturns _ dims_n_tiles _ ->
      t `arrayOfShape` Shape [dim | (dim, _, _) <- dims_n_tiles]

-- | The return type of a 'SegOp'.
--
-- Operator results ('SegRed', 'SegScan', 'SegHist') come first.  For
-- 'SegRed' and 'SegScan' the remaining ("mapped") kernel results then
-- follow, with shapes given by 'segResultShape'.  Reduction and
-- histogram results are prefixed by the segment dimensions (every
-- dimension of the space except the innermost); scan results by all
-- dimensions of the space.
segOpType :: SegOp lvl rep -> [Type]
segOpType segop =
  case segop of
    SegMap _ space ts kbody ->
      zipWith (segResultShape space) ts (kernelBodyResult kbody)
    SegRed _ space reds ts kbody ->
      withMapResults space ts kbody $
        binOpResultTypes (init (segSpaceDims space)) reds
    SegScan _ space scans ts kbody ->
      withMapResults space ts kbody $
        binOpResultTypes (segSpaceDims space) scans
    SegHist _ space ops _ _ ->
      let segment_dims = init $ segSpaceDims space
       in [ t `arrayOfShape` (Shape segment_dims <> histShape op <> histOpShape op)
            | op <- ops,
              t <- lambdaReturnType (histOp op)
          ]
  where
    -- Every operator result, with the given outer dimensions and the
    -- operator's own vector shape added on the outside.
    binOpResultTypes outer binops =
      [ t `arrayOfShape` (Shape outer <> segBinOpShape op)
        | op <- binops,
          t <- lambdaReturnType (segBinOpLambda op)
      ]
    -- The operator results followed by the mapped results; the first
    -- (number of operator results) types and kernel results are the
    -- operator inputs and are skipped.
    withMapResults space ts kbody op_ts =
      let k = length op_ts
       in op_ts
            ++ zipWith
              (segResultShape space)
              (drop k ts)
              (drop k $ kernelBodyResult kbody)

-- | The type of a 'SegOp' is always fully static: it is 'segOpType'
-- lifted to existential types.
instance TypedOp (SegOp lvl rep) where
  opType segop = pure $ staticShapes $ segOpType segop

instance (ASTConstraints lvl, Aliased rep) => AliasedOp (SegOp lvl rep) where
  -- Every result of a 'SegOp' is reported as aliasing nothing.
  opAliases segop = [mempty | _ <- segOpType segop]

  -- A 'SegHist' consumes its destination arrays in addition to
  -- whatever its body consumes; the other constructors consume only
  -- what their body consumes.
  consumedInOp (SegHist _ _ ops _ kbody) =
    namesFromList (concatMap histDest ops) <> consumedInKernelBody kbody
  consumedInOp segop =
    consumedInKernelBody $ segBody segop

-- | Type check a 'SegOp', given a checker for its level.
--
-- 'SegMap', 'SegRed' and 'SegScan' are checked by 'checkScanRed' (a
-- 'SegMap' simply has no operators).  For 'SegHist', each 'HistOp'
-- must have 64-bit integer sizes and an operator whose return type
-- matches its neutral elements.  Its destination arrays must have the
-- right shape, and they are consumed.  The body must return one 64-bit
-- bucket index per histogram dimension of each operation, followed by
-- the values to write.
typeCheckSegOp ::
  (TC.Checkable rep) =>
  (lvl -> TC.TypeM rep ()) ->
  SegOp lvl (Aliases rep) ->
  TC.TypeM rep ()
typeCheckSegOp checkLvl (SegMap lvl space ts kbody) = do
  checkLvl lvl
  checkScanRed space [] ts kbody
typeCheckSegOp checkLvl (SegRed lvl space reds ts body) = do
  checkLvl lvl
  checkScanRed space (map segBinOpTriple reds) ts body
typeCheckSegOp checkLvl (SegScan lvl space scans ts body) = do
  checkLvl lvl
  checkScanRed space (map segBinOpTriple scans) ts body
typeCheckSegOp checkLvl (SegHist lvl space ops ts kbody) = do
  checkLvl lvl
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    nes_ts <- mapM checkHistOp ops

    checkKernelBody ts kbody

    -- Return type of bucket function must be an index for each
    -- operation followed by the values to write.
    let index_ts = concatMap (\op -> replicate (shapeRank (histShape op)) (Prim int64)) ops
        bucket_ret_t = index_ts ++ concat nes_ts
    unless (bucket_ret_t == ts) $
      TC.bad . TC.TypeError $
        "SegHist body has return type "
          <> prettyTuple ts
          <> " but should have type "
          <> prettyTuple bucket_ret_t
  where
    segment_dims = init $ segSpaceDims space

    checkInt64 se = TC.require [Prim int64] se

    -- Check a single histogram operation and return the types of its
    -- per-thread values (the neutral element types, with the
    -- operator's vector shape added).
    checkHistOp (HistOp dest_shape rf dests nes shape op) = do
      mapM_ checkInt64 dest_shape
      checkInt64 rf
      nes' <- mapM TC.checkArg nes
      mapM_ checkInt64 $ shapeDims shape

      -- Operator type must match the type of neutral elements.
      let stripVecDims = stripArray $ shapeRank shape
          op_args = map (TC.noArgAliases . first stripVecDims) (nes' ++ nes')
      TC.checkLambda op op_args
      let nes_t = map TC.argType nes'
      unless (nes_t == lambdaReturnType op) $
        TC.bad . TC.TypeError $
          "SegHist operator has return type "
            <> prettyTuple (lambdaReturnType op)
            <> " but neutral element has type "
            <> prettyTuple nes_t

      -- Arrays must have proper type.
      let dest_shape' = Shape segment_dims <> dest_shape <> shape
      forM_ (zip nes_t dests) $ \(t, dest) -> do
        TC.requireI [t `arrayOfShape` dest_shape'] dest
        TC.consume =<< TC.lookupAliases dest

      pure $ map (`arrayOfShape` shape) nes_t

-- | Split a 'SegBinOp' into the operator, neutral elements and vector
-- shape, in the form 'checkScanRed' takes them.
segBinOpTriple :: SegBinOp rep -> (Lambda rep, [SubExp], Shape)
segBinOpTriple op = (segBinOpLambda op, segBinOpNeutral op, segBinOpShape op)

-- | Shared type checking for 'SegRed' and 'SegScan'.
--
-- Checks the segment space and the declared result types.  Then, inside the
-- scope of the segment space, checks each operator:
--
--   * every dimension of the operator's 'Shape' must be an @i64@;
--   * the lambda must accept the neutral elements twice over
--     (accumulator and element);
--   * the lambda's return type must equal the neutral-element types.
--
-- The neutral-element types of all operators, each lifted to the operator's
-- 'Shape', must form a prefix of the body's result types.  Finally the
-- kernel body itself is checked against @ts@.
checkScanRed ::
  TC.Checkable rep =>
  SegSpace ->
  [(Lambda (Aliases rep), [SubExp], Shape)] ->
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkScanRed space ops ts kbody = do
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    op_ts <- concat <$> mapM checkOp ops
    -- Any results beyond the operator results are not constrained here.
    let body_prefix = take (length op_ts) ts
    when (op_ts /= body_prefix) $
      TC.bad . TC.TypeError $
        "Wrong return for body (does not match neutral elements; expected "
          <> prettyText op_ts
          <> "; found "
          <> prettyText body_prefix
          <> ")"

    checkKernelBody ts kbody
  where
    -- Check one operator and return the types it expects from the body.
    checkOp (lam, nes, shape) = do
      mapM_ (TC.require [Prim int64]) $ shapeDims shape
      nes' <- mapM TC.checkArg nes

      -- Operator type must match the type of neutral elements.
      TC.checkLambda lam $ map TC.noArgAliases $ nes' ++ nes'
      let nes_t = map TC.argType nes'
      when (lambdaReturnType lam /= nes_t) $
        TC.bad $
          TC.TypeError "wrong type for operator or neutral elements."

      pure $ map (`arrayOfShape` shape) nes_t

-- | Like 'Mapper', but just for 'SegOp's.  Each field is the monadic action
-- applied to one kind of component of a 'SegOp' when traversing it with
-- 'mapSegOpM'.
data SegOpMapper lvl frep trep m = SegOpMapper
  { -- | Applied to 'SubExp's, such as segment-space dimensions, neutral
    -- elements and operator shape dimensions.
    mapOnSegOpSubExp :: SubExp -> m SubExp,
    -- | Applied to operator lambdas.
    mapOnSegOpLambda :: Lambda frep -> m (Lambda trep),
    -- | Applied to the kernel body.
    mapOnSegOpBody :: KernelBody frep -> m (KernelBody trep),
    -- | Applied to bare 'VName's, such as the segment-space indices and
    -- histogram destination arrays.
    mapOnSegOpVName :: VName -> m VName,
    -- | Applied to the level annotation.
    mapOnSegOpLevel :: lvl -> m lvl
  }

-- | A mapper that simply returns the 'SegOp' verbatim.  Every field is
-- 'pure', so it is a convenient base for record updates that override only
-- some of the fields.
identitySegOpMapper :: Monad m => SegOpMapper lvl rep rep m
identitySegOpMapper = SegOpMapper pure pure pure pure pure

-- | Traverse a 'SegSpace'.  The physical index and each dimension's index
-- name go through 'mapOnSegOpVName'; each dimension size goes through
-- 'mapOnSegOpSubExp'.  Effects run in order: physical index first, then each
-- (index, size) pair.
mapOnSegSpace ::
  Monad f => SegOpMapper lvl frep trep f -> SegSpace -> f SegSpace
mapOnSegSpace tv (SegSpace phys dims) = do
  phys' <- mapOnSegOpVName tv phys
  dims' <- forM dims $ \(gtid, dim) ->
    (,) <$> mapOnSegOpVName tv gtid <*> mapOnSegOpSubExp tv dim
  pure $ SegSpace phys' dims'

-- | Traverse a 'SegBinOp'.  The operator lambda, the neutral elements and
-- the operator's shape dimensions are visited in that order.  The
-- commutativity flag is kept unchanged.
mapSegBinOp ::
  Monad m =>
  SegOpMapper lvl frep trep m ->
  SegBinOp frep ->
  m (SegBinOp trep)
mapSegBinOp tv (SegBinOp comm red_op nes shape) = do
  red_op' <- mapOnSegOpLambda tv red_op
  nes' <- mapM (mapOnSegOpSubExp tv) nes
  dims' <- mapM (mapOnSegOpSubExp tv) $ shapeDims shape
  pure $ SegBinOp comm red_op' nes' (Shape dims')

-- | Apply a 'SegOpMapper' to the given 'SegOp'.
--
-- For every kind of 'SegOp', the level, the segment space, the result types
-- and the body are visited.  Operators ('SegBinOp's and 'HistOp's) are also
-- visited where present.
--
-- Result types are traversed with 'mapOnSegOpType' in every case, so the
-- accumulator name of an 'Acc' result type is handed to 'mapOnSegOpVName'
-- no matter which kind of 'SegOp' it occurs in.
mapSegOpM ::
  Monad m =>
  SegOpMapper lvl frep trep m ->
  SegOp lvl frep ->
  m (SegOp lvl trep)
mapSegOpM tv (SegMap lvl space ts body) =
  SegMap
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegRed lvl space reds ts body) =
  SegRed
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapSegBinOp tv) reds
    -- 'mapOnSegOpType' (not 'mapOnType', which only takes a 'SubExp'
    -- function) so 'Acc' names are visited exactly as for 'SegMap'.
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegScan lvl space scans ts body) =
  SegScan
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapSegBinOp tv) scans
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegHist lvl space ops ts body) =
  SegHist
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM onHistOp ops
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
  where
    -- Every component of a 'HistOp' is visited.  Destination arrays are bare
    -- 'VName's; widths, race factor, neutral elements and shape dimensions
    -- are 'SubExp's.
    onHistOp (HistOp w rf arrs nes shape op) =
      HistOp
        <$> mapM (mapOnSegOpSubExp tv) w
        <*> mapOnSegOpSubExp tv rf
        <*> mapM (mapOnSegOpVName tv) arrs
        <*> mapM (mapOnSegOpSubExp tv) nes
        <*> (Shape <$> mapM (mapOnSegOpSubExp tv) (shapeDims shape))
        <*> mapOnSegOpLambda tv op

-- | Traverse a 'Type' with a 'SegOpMapper'.
--
-- Primitive and memory types are returned as they are.  For arrays, each
-- shape dimension goes through 'mapOnSegOpSubExp'.  For accumulators, the
-- accumulator name goes through 'mapOnSegOpVName', and the index-space
-- dimensions and the shapes of the element types go through
-- 'mapOnSegOpSubExp'.  Uniqueness annotations are preserved.
mapOnSegOpType ::
  Monad m =>
  SegOpMapper lvl frep trep m ->
  Type ->
  m Type
mapOnSegOpType tv ty =
  case ty of
    Prim {} ->
      pure ty
    Mem space ->
      pure $ Mem space
    Array et shape u ->
      Array et <$> traverse onSubExp shape <*> pure u
    Acc acc ispace elem_ts u ->
      Acc
        <$> mapOnSegOpVName tv acc
        <*> traverse onSubExp ispace
        <*> traverse (bitraverse (traverse onSubExp) pure) elem_ts
        <*> pure u
  where
    onSubExp = mapOnSegOpSubExp tv

-- | Rephrase the operator lambda of a 'SegBinOp'.  The commutativity,
-- neutral elements and shape are copied unchanged.
rephraseBinOp ::
  Monad f =>
  Rephraser f from rep ->
  SegBinOp from ->
  f (SegBinOp rep)
rephraseBinOp r (SegBinOp comm lam nes shape) = do
  lam' <- rephraseLambda r lam
  pure $ SegBinOp comm lam' nes shape

-- | Rephrase a 'KernelBody': first its body decoration, then each of its
-- statements.  The kernel results are copied unchanged.
rephraseKernelBody ::
  Monad f =>
  Rephraser f from rep ->
  KernelBody from ->
  f (KernelBody rep)
rephraseKernelBody r (KernelBody dec stms res) = do
  dec' <- rephraseBodyDec r dec
  stms' <- traverse (rephraseStm r) stms
  pure $ KernelBody dec' stms' res

-- Rephrasing changes only representation-dependent parts: operator lambdas
-- and the kernel body.  Level, segment space and result types are kept.
instance RephraseOp (SegOp lvl) where
  rephraseInOp r segop = case segop of
    SegMap lvl space ts body ->
      SegMap lvl space ts <$> onBody body
    SegRed lvl space reds ts body -> do
      reds' <- mapM (rephraseBinOp r) reds
      SegRed lvl space reds' ts <$> onBody body
    SegScan lvl space scans ts body -> do
      scans' <- mapM (rephraseBinOp r) scans
      SegScan lvl space scans' ts <$> onBody body
    SegHist lvl space hists ts body -> do
      hists' <- mapM onHist hists
      SegHist lvl space hists' ts <$> onBody body
    where
      onBody = rephraseKernelBody r
      -- Only the histogram operator's lambda depends on the representation.
      onHist (HistOp w rf arrs nes shape op) =
        HistOp w rf arrs nes shape <$> rephraseLambda r op

-- | A helper for defining 'TraverseOpStms'.
--
-- Applies the statement traverser to the kernel body's statements and to
-- the statements of every operator lambda.  The scope passed to the
-- traverser always includes the names bound by the 'SegOp's segment space.
-- For lambdas, that scope is combined with the scope that
-- 'traverseLambdaStms' supplies.
traverseSegOpStms :: Monad m => OpStmsTraverser m (SegOp lvl rep) rep
traverseSegOpStms f segop =
  mapSegOpM
    ( identitySegOpMapper
        { mapOnSegOpLambda = traverseLambdaStms inSegScope,
          mapOnSegOpBody = onBody
        }
    )
    segop
  where
    seg_scope = scopeOfSegSpace $ segSpace segop
    inSegScope scope = f $ seg_scope <> scope
    onBody (KernelBody dec stms res) = do
      stms' <- f seg_scope stms
      pure $ KernelBody dec stms' res

-- Substitution is applied uniformly to every component of the 'SegOp' that
-- 'mapSegOpM' visits.
instance
  (ASTRep rep, Substitute lvl) =>
  Substitute (SegOp lvl rep)
  where
  substituteNames subst op = runIdentity $ mapSegOpM substituter op
    where
      -- Needs its own signature: it is used at five different types.
      sub :: Substitute a => a -> Identity a
      sub = pure . substituteNames subst
      substituter = SegOpMapper sub sub sub sub sub

-- The names bound by the segment space are renamed freshly, and everything
-- 'mapSegOpM' visits is renamed under that binding.
instance (ASTRep rep, ASTConstraints lvl) => Rename (SegOp lvl rep) where
  rename op = renameBound bound_by_space $ mapSegOpM renamer op
    where
      bound_by_space = M.keys $ scopeOfSegSpace $ segSpace op
      renamer =
        SegOpMapper
          { mapOnSegOpSubExp = rename,
            mapOnSegOpLambda = rename,
            mapOnSegOpBody = rename,
            mapOnSegOpVName = rename,
            mapOnSegOpLevel = rename
          }

instance (ASTRep rep, FreeIn lvl) => FreeIn (SegOp lvl rep) where
  -- The free variables of every component visited by 'mapSegOpM' are
  -- accumulated in a 'State'.  The names in the scope of the 'SegSpace'
  -- are then removed with 'fvBind'.
  freeIn' e = fvBind boundBySpace collected
    where
      boundBySpace = namesFromList $ M.keys $ scopeOfSegSpace $ segSpace e
      collected = execState (mapSegOpM collector e) mempty
      -- Add the free variables of a component to the state and return
      -- the component unchanged.
      note x = x <$ modify (<> freeIn' x)
      collector =
        SegOpMapper
          { mapOnSegOpSubExp = note,
            mapOnSegOpLambda = note,
            mapOnSegOpBody = note,
            mapOnSegOpVName = note,
            mapOnSegOpLevel = note
          }

instance OpMetrics (Op rep) => OpMetrics (SegOp lvl rep) where
  -- Each operation is recorded under its constructor name.  For
  -- reductions, scans and histograms, the operator lambdas are visited
  -- before the kernel body.
  opMetrics op = case op of
    SegMap _ _ _ body ->
      inside "SegMap" $ kernelBodyMetrics body
    SegRed _ _ reds _ body ->
      withOperators "SegRed" (map segBinOpLambda reds) body
    SegScan _ _ scans _ body ->
      withOperators "SegScan" (map segBinOpLambda scans) body
    SegHist _ _ hists _ body ->
      withOperators "SegHist" (map histOp hists) body
    where
      withOperators name lams body =
        inside name $ do
          traverse_ lambdaMetrics lams
          kernelBodyMetrics body

instance Pretty SegSpace where
  -- Each dimension is rendered as @i < d@ and the list is passed to
  -- 'apply'.  The name stored in the first field follows as @(~phys)@.
  pretty (SegSpace phys dims) =
    apply [pretty i <+> "<" <+> pretty d | (i, d) <- dims]
      <+> parens ("~" <> pretty phys)

instance PrettyRep rep => Pretty (SegBinOp rep) where
  -- Rendered as: the neutral elements in braces, the shape, then the
  -- lambda.  For commutative operators the lambda is prefixed with
  -- "commutative ".
  pretty (SegBinOp comm lam nes shape) =
    neutral <> PP.comma </> pretty shape <> PP.comma </> commPrefix <> pretty lam
    where
      neutral = PP.braces $ PP.commasep $ map pretty nes
      commPrefix
        | Commutative <- comm = "commutative "
        | otherwise = mempty

instance (PrettyRep rep, PP.Pretty lvl) => PP.Pretty (SegOp lvl rep) where
  -- Every operation starts with its keyword immediately followed by the
  -- level, then the aligned 'SegSpace'.  Reductions, scans and histograms
  -- next list their operators in parentheses, separated by a comma and
  -- 'PP.line'.  All variants end with the result types and the kernel
  -- body in braces.
  pretty op = case op of
    SegMap lvl space ts body ->
      header "segmap" lvl space <+> typedBody ts body
    SegRed lvl space reds ts body ->
      header "segred" lvl space
        </> operators (map pretty reds)
        </> typedBody ts body
    SegScan lvl space scans ts body ->
      header "segscan" lvl space
        </> operators (map pretty scans)
        </> typedBody ts body
    SegHist lvl space hists ts body ->
      header "seghist" lvl space
        </> operators (map ppHistOp hists)
        </> typedBody ts body
    where
      header keyword level segspace =
        keyword <> pretty level </> PP.align (pretty segspace)
      operators = PP.parens . mconcat . intersperse (PP.comma <> PP.line)
      typedBody ts body =
        PP.colon <+> ppTuple' (map pretty ts) <+> PP.nestedBlock "{" "}" (pretty body)
      bracedList xs = PP.braces $ PP.commasep $ map pretty xs
      -- Histogram operators: width, race factor, destinations,
      -- neutral elements, operator shape, and finally the lambda.
      ppHistOp (HistOp w rf dests nes shape lam) =
        pretty w <> PP.comma
          <+> pretty rf <> PP.comma
          </> bracedList dests <> PP.comma
          </> bracedList nes <> PP.comma
          </> pretty shape <> PP.comma
          </> pretty lam

instance CanBeAliased (SegOp lvl) where
  -- Alias analysis is applied to the lambdas ('Alias.analyseLambda') and
  -- to the kernel body ('aliasAnalyseKernelBody').  All other components
  -- pass through unchanged.
  addOpAliases aliases op = runIdentity $ mapSegOpM analyse op
    where
      analyse =
        SegOpMapper
          { mapOnSegOpSubExp = pure,
            mapOnSegOpLambda = pure . Alias.analyseLambda aliases,
            mapOnSegOpBody = pure . aliasAnalyseKernelBody aliases,
            mapOnSegOpVName = pure,
            mapOnSegOpLevel = pure
          }

-- | Convert a kernel body to its 'Wise' form.  The statements are
-- informed with 'informStms'; the body decoration and the results are
-- passed on to 'mkWiseKernelBody'.
informKernelBody :: Informing rep => KernelBody rep -> KernelBody (Wise rep)
informKernelBody (KernelBody body_dec body_stms body_res) =
  mkWiseKernelBody body_dec wise_stms body_res
  where
    wise_stms = informStms body_stms

instance CanBeWise (SegOp lvl) where
  -- Wisdom is added to the lambdas (via 'informLambda') and to the kernel
  -- body (via 'informKernelBody').  The remaining components are kept.
  addOpWisdom op = runIdentity $ mapSegOpM wisdom op
    where
      wisdom =
        SegOpMapper
          { mapOnSegOpSubExp = pure,
            mapOnSegOpLambda = pure . informLambda,
            mapOnSegOpBody = pure . informKernelBody,
            mapOnSegOpVName = pure,
            mapOnSegOpLevel = pure
          }

instance ASTRep rep => ST.IndexOp (SegOp lvl rep) where
  -- Symbolic indexing is only attempted for the @k@th result of a
  -- 'SegMap', and only when that result is a 'Returns' marked
  -- 'ResultMaySimplify' whose value is a variable.  The leading indices
  -- are associated with the 'SegSpace' index names; any excess indices
  -- are used for array indexing inside the body.  The body statements
  -- are scanned in order to build a table of known index expressions,
  -- and the result variable is looked up in that table.
  indexOp vtable k (SegMap _ space _ kbody) is = do
    Returns ResultMaySimplify _ (Var res_v) <- maybeNth k $ kernelBodyResult kbody
    guard $ length gtids <= length is
    M.lookup res_v $ foldl' extendTable initial_table $ kernelBodyStms kbody
    where
      gtids = map fst $ unSegSpace space

      -- Indexes in excess of what is used to index through the
      -- segment dimensions.
      excess_is = drop (length gtids) is

      -- Each segment index name is mapped to the corresponding index.
      initial_table =
        M.fromList $ zip gtids $ map (ST.Indexed mempty . untyped) is

      -- For a statement binding exactly one name, record either the
      -- scalar expression it computes or the array index it performs;
      -- if neither applies, the table is returned unchanged.
      extendTable table stm = fromMaybe table $ do
        [v] <- Just $ patNames $ stmPat stm
        scalarEntry table stm v `mplus` arrayEntry table stm v

      scalarEntry table stm v = do
        (pe, cs) <- runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm
        pure $ M.insert v (ST.Indexed (stmCerts stm <> cs) pe) table

      arrayEntry table stm v = do
        BasicOp (Index arr slice) <- Just $ stmExp stm
        guard $ length (sliceDims slice) == length excess_is
        guard $ arr `ST.available` vtable
        (slice', cs) <- asPrimExpSlice table slice
        let arr_is = fixSlice (fmap isInt64 slice') excess_is
        pure $ M.insert v (ST.IndexedArray (stmCerts stm <> cs) arr arr_is) table

      asPrimExpSlice table =
        runWriterT . traverse (primExpFromSubExpM (asPrimExp table))

      -- Translate a name to a 'PrimExp', emitting the certificates of any
      -- table entry that is used.  Names not in the table are accepted
      -- only if 'vtable' gives them a primitive type.
      asPrimExp table v
        | Just (ST.Indexed cs e) <- M.lookup v table = tell cs >> pure e
        | Just (Prim pt) <- ST.lookupType v vtable = pure $ LeafExp v pt
        | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing

instance
  (ASTRep rep, ASTConstraints lvl) =>
  IsOp (SegOp lvl rep)
  where
  -- Every segmented operation reports 'cheapOp' as False and 'safeOp'
  -- as True, regardless of its contents.
  cheapOp = const False
  safeOp = const True

--- Simplification

instance Engine.Simplifiable SegSpace where
  -- Only the dimension sizes are simplified.  The name in the first
  -- field and the per-dimension index names are kept unchanged.
  simplify (SegSpace phys dims) =
    SegSpace phys <$> forM dims (\(i, d) -> (,) i <$> Engine.simplify d)

instance Engine.Simplifiable KernelResult where
  -- Every component is simplified in field order, except the
  -- 'ResultManifest' of a 'Returns', which is kept as-is.
  simplify (Returns manifest cs what) = do
    cs' <- Engine.simplify cs
    what' <- Engine.simplify what
    pure $ Returns manifest cs' what'
  simplify (WriteReturns cs ws arr res) = do
    cs' <- Engine.simplify cs
    ws' <- Engine.simplify ws
    arr' <- Engine.simplify arr
    res' <- Engine.simplify res
    pure $ WriteReturns cs' ws' arr' res'
  simplify (TileReturns cs dims what) = do
    cs' <- Engine.simplify cs
    dims' <- Engine.simplify dims
    what' <- Engine.simplify what
    pure $ TileReturns cs' dims' what'
  simplify (RegTileReturns cs dims_n_tiles what) = do
    cs' <- Engine.simplify cs
    dims_n_tiles' <- Engine.simplify dims_n_tiles
    what' <- Engine.simplify what
    pure $ RegTileReturns cs' dims_n_tiles' what'

-- | Build a 'KernelBody' in the 'Wise' representation.  The decoration is
-- taken from the 'Body' that 'mkWiseBody' produces for the same
-- statements and the 'kernelResultSubExp' of each result.  The
-- statements and kernel results are stored unchanged.
mkWiseKernelBody ::
  Informing rep =>
  BodyDec rep ->
  Stms (Wise rep) ->
  [KernelResult] ->
  KernelBody (Wise rep)
mkWiseKernelBody dec stms res =
  KernelBody wise_dec stms res
  where
    Body wise_dec _ _ =
      mkWiseBody dec stms $ subExpsRes $ map kernelResultSubExp res

-- | Construct a 'KernelBody' in a 'MonadBuilder'.  The decoration comes
-- from the 'Body' built by 'mkBodyM' over the given statements and the
-- 'kernelResultSubExp' of each result.  The statements and kernel
-- results are stored as given.
mkKernelBodyM ::
  MonadBuilder m =>
  Stms (Rep m) ->
  [KernelResult] ->
  m (KernelBody (Rep m))
mkKernelBodyM stms kres =
  mkBodyM stms (subExpsRes $ map kernelResultSubExp kres)
    >>= \(Body dec _ _) -> pure $ KernelBody dec stms kres

-- | Simplify the body of a segmented operation over the given 'SegSpace'.
--
-- Statements go through 'Engine.blockIf'.  A statement is blocked if it
-- uses a name in the scope of the 'SegSpace', is an op, matches the
-- engine's 'Engine.blockHoistPar' predicate, is consumed or consuming,
-- or is device-migrated.  The statements that 'Engine.blockIf' returns
-- as its third component are passed back as the second component of the
-- result.
--
-- Simplification runs inside 'Engine.enterLoop', with the 'SegSpace'
-- symbol table added, 'ST.simplifyMemory' enabled, and the arrays
-- targeted by 'WriteReturns' results marked as consumed.  The kernel
-- results are simplified with the names bound by the body's statements
-- hidden via 'ST.hideCertified'.
simplifyKernelBody ::
  (Engine.SimplifiableRep rep, BodyDec rep ~ ()) =>
  SegSpace ->
  KernelBody (Wise rep) ->
  Engine.SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
simplifyKernelBody space (KernelBody _ stms res) = do
  par_blocker <-
    Engine.asksEngineEnv (Engine.blockHoistPar . Engine.envHoistBlockers)

  -- Left-nested chain of 'Engine.orIf', starting with the SegSpace names.
  let blocker =
        foldl'
          Engine.orIf
          (Engine.hasFree bound_here)
          [ Engine.isOp,
            par_blocker,
            Engine.isConsumed,
            Engine.isConsuming,
            Engine.isDeviceMigrated
          ]

  -- Ensure we do not try to use anything that is consumed in the result.
  let consumed = foldMap consumedInResult res
      markConsumed vtable = foldl' (flip ST.consume) vtable consumed
      enableMemory vtable = vtable {ST.simplifyMemory = True}
      inKernel =
        Engine.localVtable markConsumed
          . Engine.localVtable (<> scope_vtable)
          . Engine.localVtable enableMemory
          . Engine.enterLoop

  (body_res, body_stms, hoisted) <-
    inKernel . Engine.blockIf blocker stms $ do
      res' <-
        Engine.localVtable (ST.hideCertified bound_in_body) $
          mapM Engine.simplify res
      pure (res', UT.usages $ freeIn res')

  pure (mkWiseKernelBody () body_stms body_res, hoisted)
  where
    scope_vtable = segSpaceSymbolTable space
    bound_here = namesFromList $ M.keys $ scopeOfSegSpace space
    bound_in_body = namesFromList $ M.keys $ scopeOf stms

    -- Only 'WriteReturns' results consume an array (the one written to).
    consumedInResult (WriteReturns _ _ arr _) = [arr]
    consumedInResult _ = []

-- Simplify a lambda that lives inside a 'SegOp'.  This is
-- 'Engine.simplifyLambda' with its result passed through
-- 'Engine.blockMigrated'.  The 'Names' argument is handed unchanged to
-- 'Engine.simplifyLambda'; the callers in this module pass the flat
-- index of the enclosing segment space.
simplifyLambda ::
  (Engine.SimplifiableRep rep) =>
  Names ->
  Lambda (Wise rep) ->
  Engine.SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambda bound lam =
  Engine.blockMigrated $ Engine.simplifyLambda bound lam

-- Construct a symbol table describing the index variables of a
-- 'SegSpace'.  The flat index is bound as an 'Int64' index name.  Each
-- per-dimension thread index is then added, left to right, as an
-- 'Int64' loop variable whose bound is the corresponding dimension.
segSpaceSymbolTable :: (ASTRep rep) => SegSpace -> ST.SymbolTable rep
segSpaceSymbolTable (SegSpace flat gtids_and_dims) =
  foldl' insertGtid flat_vtable gtids_and_dims
  where
    flat_vtable = ST.fromScope $ M.singleton flat $ IndexName Int64
    insertGtid acc (gtid, dim) = ST.insertLoopVar gtid Int64 dim acc

-- Simplify one 'SegBinOp'.  The operator lambda is simplified first,
-- with 'ST.simplifyMemory' switched on in the symbol table and with the
-- given name (the flat physical index) forwarded to 'simplifyLambda'.
-- Then the shape and the neutral elements are simplified.  The result
-- pairs the new operator with the statements hoisted out of its lambda.
simplifySegBinOp ::
  (Engine.SimplifiableRep rep) =>
  VName ->
  SegBinOp (Wise rep) ->
  Engine.SimpleM rep (SegBinOp (Wise rep), Stms (Wise rep))
simplifySegBinOp phys_id (SegBinOp comm lam nes shape) = do
  (lam', lam_hoisted) <-
    Engine.localVtable enableMemorySimplification $
      simplifyLambda (oneName phys_id) lam
  shape' <- Engine.simplify shape
  nes' <- mapM Engine.simplify nes
  pure (SegBinOp comm lam' nes' shape', lam_hoisted)
  where
    enableMemorySimplification vtable = vtable {ST.simplifyMemory = True}

-- | Simplify the given 'SegOp'.
--
-- The level, space and result types are simplified first.  Next come
-- any operators (reductions, scans or histogram operators), whose
-- lambdas are simplified with the segment space's scope added to the
-- symbol table.  The kernel body is simplified last.  The returned
-- statements are those hoisted out of the operators, followed by those
-- hoisted out of the kernel body.
simplifySegOp ::
  ( Engine.SimplifiableRep rep,
    BodyDec rep ~ (),
    Engine.Simplifiable lvl
  ) =>
  SegOp lvl (Wise rep) ->
  Engine.SimpleM rep (SegOp lvl (Wise rep), Stms (Wise rep))
simplifySegOp (SegMap lvl space ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure (SegMap lvl' space' ts' kbody', body_hoisted)
simplifySegOp (SegRed lvl space reds ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (reds', reds_hoisted) <- simplifyOps reds
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure
    ( SegRed lvl' space' reds' ts' kbody',
      mconcat reds_hoisted <> body_hoisted
    )
  where
    -- Operators see the (unsimplified) segment space in scope.
    simplifyOps =
      Engine.localVtable (<> ST.fromScope (scopeOfSegSpace space))
        . mapAndUnzipM (simplifySegBinOp $ segFlat space)
simplifySegOp (SegScan lvl space scans ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (scans', scans_hoisted) <- simplifyOps scans
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure
    ( SegScan lvl' space' scans' ts' kbody',
      mconcat scans_hoisted <> body_hoisted
    )
  where
    -- Operators see the (unsimplified) segment space in scope.
    simplifyOps =
      Engine.localVtable (<> ST.fromScope (scopeOfSegSpace space))
        . mapAndUnzipM (simplifySegBinOp $ segFlat space)
simplifySegOp (SegHist lvl space ops ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (ops', ops_hoisted) <- mapAndUnzipM simplifyHistOp ops
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure
    ( SegHist lvl' space' ops' ts' kbody',
      mconcat ops_hoisted <> body_hoisted
    )
  where
    -- Simplify every component of a histogram operator.  The lambda is
    -- simplified with the segment space in scope and with memory
    -- simplification enabled.
    simplifyHistOp (HistOp w rf arrs nes dims lam) = do
      w' <- Engine.simplify w
      rf' <- Engine.simplify rf
      arrs' <- Engine.simplify arrs
      nes' <- Engine.simplify nes
      dims' <- Engine.simplify dims
      (lam', op_hoisted) <-
        Engine.localVtable (<> ST.fromScope (scopeOfSegSpace space)) $
          Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True}) $
            simplifyLambda (oneName $ segFlat space) lam
      pure (HistOp w' rf' arrs' nes' dims' lam', op_hoisted)

-- | Representations whose t'Op's may contain 'SegOp's.  A
-- representation must be an instance of this class for the 'SegOp'
-- simplification rules ('segOpRules') to work.
class HasSegOp rep where
  -- | The level annotation carried by this representation's 'SegOp's.
  type SegOpLevel rep

  -- | Wrap a 'SegOp' as an t'Op' of this representation.
  segOp :: SegOp (SegOpLevel rep) rep -> Op rep

  -- | View an t'Op' as a 'SegOp', or 'Nothing' if it is not one.
  asSegOp :: Op rep -> Maybe (SegOp (SegOpLevel rep) rep)

-- | Simplification rules for simplifying 'SegOp's: one top-down rule
-- and one bottom-up rule, each applying only to statements whose t'Op'
-- is a 'SegOp'.
segOpRules ::
  (HasSegOp rep, BuilderOps rep, Buildable rep, Aliased rep) =>
  RuleBook rep
segOpRules = ruleBook top_down bottom_up
  where
    top_down = [RuleOp segOpRuleTopDown]
    bottom_up = [RuleOp segOpRuleBottomUp]

-- Top-down rule entry point: dispatch to 'topDownSegOp' when the t'Op'
-- is a 'SegOp', and 'Skip' otherwise.
segOpRuleTopDown ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  TopDownRuleOp rep
segOpRuleTopDown vtable pat dec =
  maybe Skip (topDownSegOp vtable pat dec) . asSegOp

-- Bottom-up rule entry point: dispatch to 'bottomUpSegOp' when the t'Op'
-- is a 'SegOp', and 'Skip' otherwise.
segOpRuleBottomUp ::
  (HasSegOp rep, BuilderOps rep, Aliased rep) =>
  BottomUpRuleOp rep
segOpRuleBottomUp vtable pat dec =
  maybe Skip (bottomUpSegOp vtable pat dec) . asSegOp

topDownSegOp ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  ST.SymbolTable rep ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
-- If a SegOp produces something invariant to the SegOp, turn it
-- into a replicate.
topDownSegOp vtable (Pat kpes) dec (SegMap lvl space ts (KernelBody _ kstms kres)) =
  Simplify $ do
    -- Results that become replicates are bound outside the SegMap (as
    -- a side effect of 'keepResult') and dropped from it here.
    (ts', kpes', kres') <- unzip3 <$> filterM keepResult (zip3 ts kpes kres)

    -- Give up if no result was removed.
    when (kres' == kres) cannotSimplify

    kbody <- mkKernelBodyM kstms kres'
    addStm . Let (Pat kpes') dec . Op . segOp $ SegMap lvl space ts' kbody
  where
    -- Constants are always invariant.  A variable counts as invariant
    -- when it is present in the symbol table this rule was given; names
    -- bound inside the kernel body are presumably not in it.
    isInvariant Constant {} = True
    isInvariant (Var v) = isJust $ ST.lookup v vtable

    -- 'False' means: this result was replaced by a replicate of the
    -- invariant value over the segment dimensions, so drop it.
    keepResult (_, pe, Returns rm cs se)
      | cs == mempty,
        rm == ResultMaySimplify,
        isInvariant se = do
          letBindNames [patElemName pe] . BasicOp $
            Replicate (Shape $ segSpaceDims space) se
          pure False
    keepResult _ = pure True

-- If a SegRed contains two reduction operations that have the same
-- vector shape, merge them together.  This saves on communication
-- overhead, but can in principle lead to more local memory usage.
topDownSegOp _ (Pat pes) _ (SegRed lvl space ops ts kbody)
  | length ops > 1,
    -- 'groupBy' only groups adjacent operators, so the relative order
    -- of the reduction results is preserved by merging.
    op_groupings <-
      groupBy sameShape . zip ops $
        chunks (map (length . segBinOpNeutral) ops) $
          zip3 red_pes red_ts red_res,
    any ((> 1) . length) op_groupings =
      Simplify $ do
        let (ops', aux) = unzip $ mapMaybe combineOps op_groupings
            (red_pes', red_ts', red_res') = unzip3 $ concat aux
            new_kbody = kbody {kernelBodyResult = red_res' ++ map_res}
            new_segop = SegRed lvl space ops' (red_ts' ++ map_ts) new_kbody
        letBind (Pat $ red_pes' ++ map_pes) . Op $ segOp new_segop
  where
    -- Reduction results come first; the rest are plain map results.
    num_red_res = segBinOpResults ops
    (red_pes, map_pes) = splitAt num_red_res pes
    (red_ts, map_ts) = splitAt num_red_res ts
    (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody

    sameShape x y = segBinOpShape (fst x) == segBinOpShape (fst y)

    combineOps grp = case grp of
      [] -> Nothing
      g : gs -> Just $ foldl' combine g gs

    -- Fuse two operators into one whose parameters are all the
    -- accumulator parameters followed by all the element parameters.
    combine (op1, op1_aux) (op2, op2_aux) =
      (merged_op, op1_aux ++ op2_aux)
      where
        lam1 = segBinOpLambda op1
        lam2 = segBinOpLambda op2
        (xs1, ys1) = splitAt (length $ segBinOpNeutral op1) $ lambdaParams lam1
        (xs2, ys2) = splitAt (length $ segBinOpNeutral op2) $ lambdaParams lam2
        merged_body =
          mkBody
            (bodyStms (lambdaBody lam1) <> bodyStms (lambdaBody lam2))
            (bodyResult (lambdaBody lam1) <> bodyResult (lambdaBody lam2))
        merged_lam =
          Lambda
            { lambdaParams = xs1 ++ xs2 ++ ys1 ++ ys2,
              lambdaReturnType = lambdaReturnType lam1 ++ lambdaReturnType lam2,
              lambdaBody = merged_body
            }
        merged_op =
          SegBinOp
            { segBinOpComm = segBinOpComm op1 <> segBinOpComm op2,
              segBinOpLambda = merged_lam,
              segBinOpNeutral = segBinOpNeutral op1 ++ segBinOpNeutral op2,
              -- Equal to the shape of op2, since grouping was by shape.
              segBinOpShape = segBinOpShape op1
            }
topDownSegOp _ _ _ _ = Skip

-- A convenient way of operating on the type and body of a SegOp,
-- without worrying about exactly what kind it is.  The tuple holds:
--
--   * the result types,
--   * the kernel body,
--   * the number of results belonging to the operators rather than
--     being plain map results (0 for a SegMap), and
--   * a function that rebuilds a SegOp of the same kind (same level,
--     space and operators) from new result types and a new body.
segOpGuts ::
  SegOp (SegOpLevel rep) rep ->
  ( [Type],
    KernelBody rep,
    Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
  )
segOpGuts segop =
  case segop of
    SegMap lvl space kts body ->
      (kts, body, 0, SegMap lvl space)
    SegScan lvl space ops kts body ->
      (kts, body, segBinOpResults ops, SegScan lvl space ops)
    SegRed lvl space ops kts body ->
      (kts, body, segBinOpResults ops, SegRed lvl space ops)
    SegHist lvl space ops kts body ->
      (kts, body, sum [length (histDest h) | h <- ops], SegHist lvl space ops)

bottomUpSegOp ::
  (Aliased rep, HasSegOp rep, BuilderOps rep) =>
  (ST.SymbolTable rep, UT.UsageTable) ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
-- | Hoist trivially-sliced results out of a 'SegOp'.
--
-- Some 'SegOp' results can be moved outside the 'SegOp', which can
-- simplify further analysis.  This rule looks for kernel statements of
-- the form
--
-- > pe = arr[gtid_1, ..., gtid_n, rest...]
--
-- where @gtid_1, ..., gtid_n@ are exactly the indices of the segment
-- space (in order), and @arr@ as well as every name free in @rest@ has
-- an entry in the symbol table @vtable@.  If @pe@ is returned by exactly
-- one 'Returns' result, and that result's position in the original
-- result list is at least the number of non-map results, the result is
-- removed from the 'SegOp' and its pattern element is instead bound
-- outside the 'SegOp' to
--
-- > arr[0:d_1:1, ..., 0:d_n:1, rest...]
--
-- where @d_i@ are the segment space dimensions.  Results that are
-- passed on to a scan/reduction/histogram operator are never removed;
-- these always come first in the result list.
--
-- When the hoisted pattern element is consumed afterwards (according to
-- the usage table), or @arr@ is consumed inside the 'SegOp', the slice
-- is bound to a fresh @*_precopy@ name and then copied.
--
-- Fails with 'cannotSimplify' if no result could be hoisted.
bottomUpSegOp ::
  (Aliased rep, HasSegOp rep, BuilderOps rep) =>
  (ST.SymbolTable rep, UT.UsageTable) ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
bottomUpSegOp (vtable, used) (Pat kpes) dec segop = Simplify $ do
  -- Walk the kernel statements, peeling off every result that can be
  -- expressed as a slice of an outer array.  The replacement bindings
  -- are added outside the SegOp as a side effect of the fold.
  (kpes', kts', kres', kstms') <-
    localScope (scopeOfSegSpace space) $
      foldM distribute (kpes, kts, kres, mempty) kstms

  -- An unchanged pattern means nothing was hoisted.
  when (kpes' == kpes) cannotSimplify

  kbody' <- localScope (scopeOfSegSpace space) $ mkKernelBodyM kstms' kres'
  addStm . Let (Pat kpes') dec . Op . segOp $ mk_segop kts' kbody'
  where
    (kts, kbody@(KernelBody _ kstms kres), num_nonmap_results, mk_segop) =
      segOpGuts segop
    space = segSpace segop
    free_in_kstms = foldMap freeIn kstms
    consumed_in_segop = consumedInKernelBody kbody

    -- The leading indices a statement must use: one 'DimFix' per thread
    -- index of the segment space, in order.
    gtid_slice = map (DimFix . Var . fst) $ unSegSpace space

    -- Full slices @0:d:1@ over every segment space dimension; these
    -- replace the thread indices when the indexing is hoisted.
    outer_slice =
      [ DimSlice (constant (0 :: Int64)) d (constant (1 :: Int64))
        | d <- segSpaceDims space
      ]

    -- Recognise @arr[gtids..., rest...]@ where @arr@ and everything free
    -- in @rest@ are known to the symbol table.  Yields the array and the
    -- trailing indices.
    sliceWithGtidsFixed (Let _ _ (BasicOp (Index arr (Slice idxs))))
      | (leading, rest) <- splitAt (length gtid_slice) idxs,
        leading == gtid_slice,
        all (isJust . (`ST.lookup` vtable)) . namesToList $
          freeIn arr <> freeIn (Slice rest) =
          Just (arr, rest)
    sliceWithGtidsFixed _ = Nothing

    -- Bind the pattern element outside the SegOp to the full slice of
    -- @arr@.  An explicit copy is inserted when the result is consumed
    -- later, or when @arr@ itself is consumed inside the SegOp.
    hoist kpe arr rest
      | needs_copy = do
          precopy <- newVName $ baseString (patElemName kpe) <> "_precopy"
          bindFullSlice precopy
          letBindNames [patElemName kpe] $ BasicOp $ Copy precopy
      | otherwise =
          bindFullSlice $ patElemName kpe
      where
        needs_copy =
          patElemName kpe `UT.isConsumed` used
            || arr `nameIn` consumed_in_segop
        bindFullSlice v =
          letBindNames [v] $ BasicOp $ Index arr $ Slice $ outer_slice <> rest

    -- Fold step.  A hoistable statement has its result removed from the
    -- accumulated pattern/types/results.  The statement itself is only
    -- kept if other kernel statements still refer to its result.
    -- Everything else is kept verbatim.
    distribute (acc_pes, acc_ts, acc_res, acc_stms) stm
      | Let (Pat [pe]) _ _ <- stm,
        Just (arr, rest) <- sliceWithGtidsFixed stm,
        Just (kpe, pes', ts', res') <- isResult acc_pes acc_ts acc_res pe = do
          hoist kpe arr rest
          let still_needed = patElemName pe `nameIn` free_in_kstms
              stms'
                | still_needed = acc_stms <> oneStm stm
                | otherwise = acc_stms
          pure (pes', ts', res', stms')
      | otherwise =
          pure (acc_pes, acc_ts, acc_res, acc_stms <> oneStm stm)

    -- Find the unique 'Returns' result that returns @pe@.  It must sit at
    -- or after index @num_nonmap_results@ of the original result list.
    -- Yields its pattern element together with the remaining pattern
    -- elements, types and results.
    isResult pes ts res pe =
      case partition returnsPe (zip3 pes ts res) of
        ([(kpe, _, _)], others)
          | Just i <- elemIndex kpe kpes,
            i >= num_nonmap_results ->
              let (pes', ts', res') = unzip3 others
               in Just (kpe, pes', ts', res')
        _ -> Nothing
      where
        returnsPe (_, _, Returns _ _ (Var v)) = v == patElemName pe
        returnsPe _ = False

--- Memory

-- | Refine the return information of a kernel body's results.
--
-- Each 'KernelResult' of the body is paired positionally with an
-- 'ExpReturns':
--
-- * A 'WriteReturns' result has its returns recomputed with
--   'varReturns' on the array it writes into.
-- * Every other result keeps the 'ExpReturns' it was paired with.
--
-- As with 'zipWithM', surplus elements of the longer list are ignored.
-- Presumably both lists always have the same length; this is not
-- checked here.
kernelBodyReturns ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  KernelBody somerep ->
  [ExpReturns] ->
  m [ExpReturns]
kernelBodyReturns kbody rets =
  zipWithM refine (kernelBodyResult kbody) rets
  where
    refine res ret =
      case res of
        WriteReturns _ _ arr _ -> varReturns arr
        _ -> pure ret

-- | Like 'segOpType', but for memory representations.
--
-- For 'SegMap', 'SegRed' and 'SegScan', the returns are derived from the
-- operation's 'opType' via 'extReturns'.  They are then refined by
-- 'kernelBodyReturns', so that 'WriteReturns' results report the array
-- they write into.  For 'SegHist', the result is the concatenation of
-- 'varReturns' over the destination arrays of every histogram operator.
segOpReturns ::
  (Mem rep inner, Monad m, HasScope rep m) =>
  SegOp lvl somerep ->
  m [ExpReturns]
segOpReturns op =
  case op of
    SegMap _ _ _ kbody -> fromKernelBody kbody
    SegRed _ _ _ _ kbody -> fromKernelBody kbody
    SegScan _ _ _ _ kbody -> fromKernelBody kbody
    SegHist _ _ hist_ops _ _ ->
      concat <$> mapM (mapM varReturns . histDest) hist_ops
  where
    -- Shared by the three kernel-body based constructors: compute the
    -- extended types first, then refine them against the body.
    fromKernelBody kbody = do
      ext_ts <- opType op
      kernelBodyReturns kbody (extReturns ext_ts)