{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Segmented operations.  These correspond to perfect @map@ nests on
-- top of /something/, except that the @map@s are conceptually only
-- over @iota@s (so there will be explicit indexing inside them).
module Futhark.IR.SegOp
  ( SegOp (..),
    SegVirt (..),
    SegSeqDims (..),
    segLevel,
    segBody,
    segSpace,
    typeCheckSegOp,
    SegSpace (..),
    scopeOfSegSpace,
    segSpaceDims,

    -- * Details
    HistOp (..),
    histType,
    splitHistResults,
    SegBinOp (..),
    segBinOpResults,
    segBinOpChunks,
    KernelBody (..),
    aliasAnalyseKernelBody,
    consumedInKernelBody,
    ResultManifest (..),
    KernelResult (..),
    kernelResultCerts,
    kernelResultSubExp,
    SplitOrdering (..),

    -- ** Generic traversal
    SegOpMapper (..),
    identitySegOpMapper,
    mapSegOpM,
    traverseSegOpStms,

    -- * Simplification
    simplifySegOp,
    HasSegOp (..),
    segOpRules,

    -- * Memory
    segOpReturns,
  )
where

import Control.Category
import Control.Monad.Identity hiding (mapM_)
import Control.Monad.Reader hiding (mapM_)
import Control.Monad.State.Strict
import Control.Monad.Writer hiding (mapM_)
import Data.Bifunctor (first)
import Data.Bitraversable
import Data.Foldable (traverse_)
import Data.List
  ( elemIndex,
    foldl',
    groupBy,
    intersperse,
    isPrefixOf,
    partition,
  )
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Futhark.Analysis.Alias as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.IR
import Futhark.IR.Aliases
  ( Aliases,
    removeLambdaAliases,
    removeStmAliases,
  )
import Futhark.IR.Mem
import Futhark.IR.Prop.Aliases
import qualified Futhark.IR.TypeCheck as TC
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty
  ( Pretty,
    commasep,
    parens,
    ppr,
    text,
    (<+>),
    (</>),
  )
import qualified Futhark.Util.Pretty as PP
import Prelude hiding (id, (.))

-- | How an array is split into chunks.
data SplitOrdering
  = -- | Carries no further information.
    SplitContiguous
  | -- | Carries a t'SubExp'; the instances in this module name it
    -- @stride@.
    SplitStrided SubExp
  deriving (Eq, Ord, Show)

-- | The free names of a 'SplitOrdering' are those of the stride, if
-- there is one; 'SplitContiguous' has none.
instance FreeIn SplitOrdering where
  freeIn' SplitContiguous = mempty
  freeIn' (SplitStrided stride) = freeIn' stride

-- | Substitution only touches the stride of a 'SplitStrided';
-- 'SplitContiguous' is returned unchanged.
instance Substitute SplitOrdering where
  substituteNames _ SplitContiguous =
    SplitContiguous
  substituteNames subst (SplitStrided stride) =
    SplitStrided $ substituteNames subst stride

-- | Renaming only touches the stride of a 'SplitStrided';
-- 'SplitContiguous' is returned unchanged.
instance Rename SplitOrdering where
  rename SplitContiguous =
    pure SplitContiguous
  rename (SplitStrided stride) =
    SplitStrided <$> rename stride

-- | An operator for 'SegHist'.
data HistOp rep = HistOp
  { -- | The rank of this shape is the number of index results that
    -- 'splitHistResults' assigns to this operator, and it forms the
    -- outer part of the shape used by 'histType'.
    histShape :: Shape,
    histRaceFactor :: SubExp,
    -- | The number of destinations is the number of value results
    -- that 'splitHistResults' assigns to this operator.
    histDest :: [VName],
    histNeutral :: [SubExp],
    -- | In case this operator is semantically a vectorised
    -- operator (corresponding to a perfect map nest in the
    -- SOACS representation), these are the logical
    -- "dimensions".  This is used to generate more efficient
    -- code.
    histOpShape :: Shape,
    -- | The operator itself.  Its return types are the element
    -- types used by 'histType'.
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)

-- | The type of a histogram produced by a 'HistOp'.  This can be
-- different from the type of the 'histDest's in case we are
-- dealing with a segmented histogram.
--
-- Each return type of the operator lambda is turned into an array
-- whose shape is 'histShape' followed by 'histOpShape'.
histType :: HistOp rep -> [Type]
histType op =
  map (`arrayOfShape` full_shape) $ lambdaReturnType $ histOp op
  where
    full_shape = histShape op <> histOpShape op

-- | Split reduction results returned by a 'KernelBody' into those
-- that correspond to indexes for the 'HistOp's, and those that
-- correspond to values.
--
-- The leading results are taken to be the indexes (as many per
-- operator as the rank of its 'histShape'); the remaining results are
-- the values (as many per operator as it has 'histDest's).  The
-- returned list pairs up, per operator, its indexes with its values.
splitHistResults :: [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ops res =
  zip (chunks ranks idxs) (chunks num_dests vals)
  where
    -- Number of index results consumed by each operator.
    ranks = map (shapeRank . histShape) ops
    -- Number of value results consumed by each operator.
    num_dests = map (length . histDest) ops
    (idxs, vals) = splitAt (sum ranks) res

-- | An operator for 'SegScan' and 'SegRed'.
data SegBinOp rep = SegBinOp
  { segBinOpComm :: Commutativity,
    segBinOpLambda :: Lambda rep,
    -- | The number of neutral elements is what 'segBinOpResults' and
    -- 'segBinOpChunks' count as the number of values of this
    -- operator.
    segBinOpNeutral :: [SubExp],
    -- | In case this operator is semantically a vectorised
    -- operator (corresponding to a perfect map nest in the
    -- SOACS representation), these are the logical
    -- "dimensions".  This is used to generate more efficient
    -- code.
    segBinOpShape :: Shape
  }
  deriving (Eq, Ord, Show)

-- | How many reduction results are produced by these 'SegBinOp's?
-- This is the total number of neutral elements across all operators.
segBinOpResults :: [SegBinOp rep] -> Int
segBinOpResults = sum . map (length . segBinOpNeutral)

-- | Split some list into chunks equal to the number of values
-- returned by each 'SegBinOp'.  The chunk size for an operator is the
-- number of its neutral elements.
segBinOpChunks :: [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks = chunks . map (length . segBinOpNeutral)

-- | The body of a 'SegOp'.
data KernelBody rep = KernelBody
  { -- | Representation-specific decoration of the body.
    kernelBodyDec :: BodyDec rep,
    -- | The statements of the body.
    kernelBodyStms :: Stms rep,
    -- | What the body returns; see 'KernelResult'.
    kernelBodyResult :: [KernelResult]
  }

-- The 'Ord', 'Show' and 'Eq' instances for 'KernelBody' are derived
-- standalone, with the @RepTypes rep@ context given explicitly.
-- NOTE(review): presumably needed because of the @BodyDec rep@ field
-- -- confirm.
deriving instance RepTypes rep => Ord (KernelBody rep)

deriving instance RepTypes rep => Show (KernelBody rep)

deriving instance RepTypes rep => Eq (KernelBody rep)

-- | Metadata about whether there is a subtle point to this
-- 'KernelResult'.  This is used to protect things like tiling, which
-- might otherwise be removed by the simplifier because they're
-- semantically redundant.  This has no semantic effect and can be
-- ignored at code generation.
data ResultManifest
  = -- | Don't simplify this one!
    ResultNoSimplify
  | -- | Go nuts.
    ResultMaySimplify
  | -- | The results produced are only used within the
    -- same physical thread later on, and can thus be
    -- kept in registers.
    ResultPrivate
  deriving (Eq, Show, Ord)

-- | A 'KernelBody' does not return an ordinary 'Result'.  Instead, it
-- returns a list of these.
data KernelResult
  = -- | Each "worker" in the kernel returns this.
    -- Whether this is a result-per-thread or a
    -- result-per-group depends on where the 'SegOp' occurs.
    Returns ResultManifest Certs SubExp
  | WriteReturns
      Certs
      Shape -- Size of array.  Must match number of dims.
      VName -- Which array
      [(Slice SubExp, SubExp)] -- Arbitrary number of index/value pairs.
  | ConcatReturns
      Certs
      SplitOrdering -- Permuted?
      SubExp -- The final size.
      SubExp -- Per-thread/group (max) chunk size.
      VName -- Chunk by this worker.
  | TileReturns
      Certs
      [(SubExp, SubExp)] -- Total/tile for each dimension
      VName -- Tile written by this worker.
      -- The TileReturns must not expect more than one
      -- result to be written per physical thread.
  | RegTileReturns
      Certs
      -- For each dim of result:
      [ ( SubExp, -- size of this dim.
          SubExp, -- block tile size for this dim.
          SubExp -- reg tile size for this dim.
        )
      ]
      VName -- Tile returned by this worker/group.
  deriving (Eq, Show, Ord)

-- | Get the certs for this 'KernelResult'.  Every constructor carries
-- exactly one 'Certs' field, which is what is returned.
kernelResultCerts :: KernelResult -> Certs
kernelResultCerts (Returns _ cs _) = cs
kernelResultCerts (WriteReturns cs _ _ _) = cs
kernelResultCerts (ConcatReturns cs _ _ _ _) = cs
kernelResultCerts (TileReturns cs _ _) = cs
kernelResultCerts (RegTileReturns cs _ _) = cs

-- | Get the root t'SubExp' corresponding to the values of a
-- 'KernelResult'.  For 'Returns' this is the returned t'SubExp'
-- itself; for every other constructor it is the 'VName' field of the
-- result wrapped in 'Var'.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp (Returns _ _ se) = se
kernelResultSubExp (WriteReturns _ _ arr _) = Var arr
kernelResultSubExp (ConcatReturns _ _ _ _ v) = Var v
kernelResultSubExp (TileReturns _ _ v) = Var v
kernelResultSubExp (RegTileReturns _ _ v) = Var v

-- | The free names of a 'KernelResult' are those of all its fields,
-- except for the 'ResultManifest' of a 'Returns', which is ignored.
instance FreeIn KernelResult where
  freeIn' (Returns _ cs what) =
    freeIn' cs <> freeIn' what
  freeIn' (WriteReturns cs rws arr res) =
    freeIn' cs <> freeIn' rws <> freeIn' arr <> freeIn' res
  freeIn' (ConcatReturns cs o w per_thread_elems v) =
    freeIn' cs <> freeIn' o <> freeIn' w <> freeIn' per_thread_elems <> freeIn' v
  freeIn' (TileReturns cs dims v) =
    freeIn' cs <> freeIn' dims <> freeIn' v
  freeIn' (RegTileReturns cs dims_n_tiles v) =
    freeIn' cs <> freeIn' dims_n_tiles <> freeIn' v

-- | The free names of a 'KernelBody' are those of its decoration,
-- statements and results, with the names bound by the statements
-- themselves marked as bound via 'fvBind'.
instance ASTRep rep => FreeIn (KernelBody rep) where
  freeIn' (KernelBody dec stms res) =
    fvBind bound_in_stms $ freeIn' dec <> freeIn' stms <> freeIn' res
    where
      -- Every name bound by some statement of the body.
      bound_in_stms = foldMap boundByStm stms

-- | Substitution is applied to each of the three components of the
-- 'KernelBody' in turn.
instance ASTRep rep => Substitute (KernelBody rep) where
  substituteNames subst (KernelBody dec stms res) =
    KernelBody
      (substituteNames subst dec)
      (substituteNames subst stms)
      (substituteNames subst res)

-- | Substitution is applied to every field of the 'KernelResult'
-- except the 'ResultManifest' of a 'Returns', which is kept as is.
instance Substitute KernelResult where
  substituteNames subst (Returns manifest cs se) =
    Returns manifest (substituteNames subst cs) (substituteNames subst se)
  substituteNames subst (WriteReturns cs rws arr res) =
    WriteReturns
      (substituteNames subst cs)
      (substituteNames subst rws)
      (substituteNames subst arr)
      (substituteNames subst res)
  substituteNames subst (ConcatReturns cs o w per_thread_elems v) =
    ConcatReturns
      (substituteNames subst cs)
      (substituteNames subst o)
      (substituteNames subst w)
      (substituteNames subst per_thread_elems)
      (substituteNames subst v)
  substituteNames subst (TileReturns cs dims v) =
    TileReturns
      (substituteNames subst cs)
      (substituteNames subst dims)
      (substituteNames subst v)
  substituteNames subst (RegTileReturns cs dims_n_tiles v) =
    RegTileReturns
      (substituteNames subst cs)
      (substituteNames subst dims_n_tiles)
      (substituteNames subst v)

-- | Renaming a 'KernelBody' renames its body decoration, then its
-- statements, and finally its results.
instance ASTRep rep => Rename (KernelBody rep) where
  rename (KernelBody dec stms res) = do
    dec' <- rename dec
    -- The results are renamed inside the continuation handed to
    -- 'renamingStms' (presumably so that the renaming of names bound
    -- by the statements is in effect for them).
    renamingStms stms $ \stms' ->
      KernelBody dec' stms' <$> rename res

-- | A 'KernelResult' binds no names of its own, so renaming it is
-- delegated to 'substituteRename', which relies on the 'Substitute'
-- instance.
instance Rename KernelResult where
  rename = substituteRename

-- | Perform alias analysis on a 'KernelBody'.
--
-- The statements and body decoration are wrapped in a temporary
-- 'Body' with an empty result list and passed to
-- 'Alias.analyseBody'.  The kernel results are not analysed; they
-- are reattached unchanged to the annotated statements.
aliasAnalyseKernelBody ::
  ( ASTRep rep,
    CanBeAliased (Op rep)
  ) =>
  AliasTable ->
  KernelBody rep ->
  KernelBody (Aliases rep)
aliasAnalyseKernelBody aliases (KernelBody dec stms res) =
  let Body dec' stms' _ = Alias.analyseBody aliases $ Body dec stms []
   in KernelBody dec' stms' res

-- | Strip alias information from a 'KernelBody'.
--
-- The alias-annotated body decoration is a pair; its first component
-- is discarded and the second is kept as the plain decoration.
-- Every statement is passed through 'removeStmAliases'.  The kernel
-- results are kept unchanged.
removeKernelBodyAliases ::
  CanBeAliased (Op rep) =>
  KernelBody (Aliases rep) ->
  KernelBody rep
removeKernelBodyAliases (KernelBody (_, dec) stms res) =
  KernelBody dec (fmap removeStmAliases stms) res

-- | Strip the simplifier's 'Wise' annotations from a 'KernelBody'.
--
-- The statements and body decoration are wrapped in a temporary
-- 'Body' with an empty result list so that 'removeBodyWisdom' can be
-- reused.  The kernel results are reattached unchanged.
removeKernelBodyWisdom ::
  CanBeWise (Op rep) =>
  KernelBody (Wise rep) ->
  KernelBody rep
removeKernelBodyWisdom (KernelBody dec stms res) =
  let Body dec' stms' _ = removeBodyWisdom $ Body dec stms []
   in KernelBody dec' stms' res

-- | The variables consumed in the kernel body.
--
-- This is the union of the names consumed by the statements (as
-- reported by 'consumedInBody' on a temporary 'Body' with an empty
-- result list) and the destination arrays of any 'WriteReturns'
-- results.  No other kind of result contributes any names.
consumedInKernelBody ::
  Aliased rep =>
  KernelBody rep ->
  Names
consumedInKernelBody (KernelBody dec stms res) =
  consumedInBody (Body dec stms []) <> mconcat (map consumedByReturn res)
  where
    -- Only a 'WriteReturns' contributes: the array it writes to.
    consumedByReturn (WriteReturns _ _ a _) = oneName a
    consumedByReturn _ = mempty

-- | Type-check a 'KernelBody' against the expected per-result types.
--
-- The first argument holds one expected type per kernel result.  The
-- check fails with a 'TC.TypeError' if the number of results differs
-- from the number of expected types, or if any individual result
-- does not match its expected type (see @checkKernelResult@ below
-- for the per-constructor rules).
checkKernelBody ::
  TC.Checkable rep =>
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkKernelBody ts (KernelBody (_, dec) stms kres) = do
  TC.checkBodyDec dec
  -- We consume the kernel results (when applicable) before
  -- type-checking the stms, so we will get an error if a statement
  -- uses an array that is written to in a result.
  mapM_ consumeKernelResult kres
  -- The results are checked inside the action passed to
  -- 'TC.checkStms', alongside the arity check.
  TC.checkStms stms $ do
    unless (length ts == length kres) $
      TC.bad $
        TC.TypeError $
          "Kernel return type is "
            ++ prettyTuple ts
            ++ ", but body returns "
            ++ show (length kres)
            ++ " values."
    zipWithM_ checkKernelResult kres ts
  where
    -- Mark the aliases of a 'WriteReturns' destination array as
    -- consumed; other results consume nothing.
    consumeKernelResult (WriteReturns _ _ arr _) =
      TC.consume =<< TC.lookupAliases arr
    consumeKernelResult _ =
      pure ()

    -- Check a single result against its expected type @t@.
    checkKernelResult (Returns _ cs what) t = do
      TC.checkCerts cs
      TC.require [t] what
    checkKernelResult (WriteReturns cs shape arr res) t = do
      TC.checkCerts cs
      -- All dimensions of the destination shape must be int64.
      mapM_ (TC.require [Prim int64]) $ shapeDims shape
      arr_t <- lookupType arr
      forM_ res $ \(slice, e) -> do
        -- Slice components are int64; the written value has type @t@.
        traverse_ (TC.require [Prim int64]) slice
        TC.require [t] e
        -- The destination must be an array of @t@ with exactly the
        -- given shape.  This is checked per write so that the error
        -- can name the offending value.
        unless (arr_t == t `arrayOfShape` shape) $
          TC.bad $
            TC.TypeError $
              "WriteReturns returning "
                ++ pretty e
                ++ " of type "
                ++ pretty t
                ++ ", shape="
                ++ pretty shape
                ++ ", but destination array has type "
                ++ pretty arr_t
    checkKernelResult (ConcatReturns cs o w per_thread_elems v) t = do
      TC.checkCerts cs
      case o of
        SplitContiguous -> pure ()
        SplitStrided stride -> TC.require [Prim int64] stride
      TC.require [Prim int64] w
      TC.require [Prim int64] per_thread_elems
      vt <- lookupType v
      -- @v@ must be an array whose rows have type @t@; its own outer
      -- size is used when building the type to compare against.
      unless (vt == t `arrayOfRow` arraySize 0 vt) $
        TC.bad $ TC.TypeError $ "Invalid type for ConcatReturns " ++ pretty v
    checkKernelResult (TileReturns cs dims v) t = do
      TC.checkCerts cs
      forM_ dims $ \(dim, tile) -> do
        TC.require [Prim int64] dim
        TC.require [Prim int64] tile
      vt <- lookupType v
      -- @v@ must be a @t@-array shaped by the tile sizes (the second
      -- component of each pair).
      unless (vt == t `arrayOfShape` Shape (map snd dims)) $
        TC.bad $ TC.TypeError $ "Invalid type for TileReturns " ++ pretty v
    checkKernelResult (RegTileReturns cs dims_n_tiles arr) t = do
      TC.checkCerts cs
      mapM_ (TC.require [Prim int64]) dims
      mapM_ (TC.require [Prim int64]) blk_tiles
      mapM_ (TC.require [Prim int64]) reg_tiles

      -- Assert that arr has element type t and shape
      -- (blk_tiles ++ reg_tiles).
      arr_t <- lookupType arr
      unless (arr_t == expected) $
        TC.bad . TC.TypeError $
          "Invalid type for RegTileReturns. Expected:\n  "
            ++ pretty expected
            ++ ",\ngot:\n  "
            ++ pretty arr_t
      where
        (dims, blk_tiles, reg_tiles) = unzip3 dims_n_tiles
        expected = t `arrayOfShape` Shape (blk_tiles ++ reg_tiles)

-- | Record metrics for a 'KernelBody' by running 'stmMetrics' on
-- each of its statements.  The kernel results contribute nothing.
kernelBodyMetrics :: OpMetrics (Op rep) => KernelBody rep -> MetricsM ()
kernelBodyMetrics = mapM_ stmMetrics . kernelBodyStms

-- | A 'KernelBody' is printed as its statements stacked vertically,
-- followed by @return@ and the comma-separated results in braces.
-- The body decoration is not printed.
instance PrettyRep rep => Pretty (KernelBody rep) where
  ppr (KernelBody _ stms res) =
    PP.stack (map ppr (stmsToList stms))
      </> text "return"
      <+> PP.braces (PP.commasep $ map ppr res)

-- | The pretty-printed certificates as a list of documents: empty
-- when there are no certificates, and a singleton otherwise.  Meant
-- to be prepended to the documents making up a printed result.
certAnnots :: Certs -> [PP.Doc]
certAnnots cs
  | cs == mempty = []
  | otherwise = [ppr cs]

-- | Every form is prefixed by its certificates, if any (see
-- 'certAnnots').
instance Pretty KernelResult where
  -- The three 'Returns' forms differ only in the keyword, which
  -- encodes the 'ResultManifest'.
  ppr (Returns ResultNoSimplify cs what) =
    PP.spread $ certAnnots cs ++ ["returns (manifest)" <+> ppr what]
  ppr (Returns ResultPrivate cs what) =
    PP.spread $ certAnnots cs ++ ["returns (private)" <+> ppr what]
  ppr (Returns ResultMaySimplify cs what) =
    PP.spread $ certAnnots cs ++ ["returns" <+> ppr what]
  -- @arr : shape@ followed by @with@ and one @slice = value@ entry
  -- per write.
  ppr (WriteReturns cs shape arr res) =
    PP.spread $
      certAnnots cs
        ++ [ ppr arr
               <+> PP.colon
               <+> ppr shape
               </> "with"
               <+> PP.apply (map ppRes res)
           ]
    where
      ppRes (slice, e) = ppr slice <+> text "=" <+> ppr e
  -- @concat(w, per_thread_elems) v@
  ppr (ConcatReturns cs SplitContiguous w per_thread_elems v) =
    PP.spread $
      certAnnots cs
        ++ [ "concat"
               <> parens (commasep [ppr w, ppr per_thread_elems])
               <+> ppr v
           ]
  -- @concat_strided(stride, w, per_thread_elems) v@
  ppr (ConcatReturns cs (SplitStrided stride) w per_thread_elems v) =
    PP.spread $
      certAnnots cs
        ++ [ "concat_strided"
               <> parens (commasep [ppr stride, ppr w, ppr per_thread_elems])
               <+> ppr v
           ]
  -- @tile(dim / tile, ...) v@
  ppr (TileReturns cs dims v) =
    PP.spread $
      certAnnots cs
        ++ ["tile" <> parens (commasep $ map onDim dims) <+> ppr v]
    where
      onDim (dim, tile) = ppr dim <+> "/" <+> ppr tile
  -- @blkreg_tile(dim / (blk_tile * reg_tile), ...) v@
  ppr (RegTileReturns cs dims_n_tiles v) =
    PP.spread $
      certAnnots cs
        ++ ["blkreg_tile" <> parens (commasep $ map onDim dims_n_tiles) <+> ppr v]
    where
      onDim (dim, blk_tile, reg_tile) =
        ppr dim <+> "/" <+> parens (ppr blk_tile <+> "*" <+> ppr reg_tile)

-- | These dimensions (indexed from 0, outermost) of the corresponding
-- 'SegSpace' should not be parallelised, but instead iterated
-- sequentially.  For example, with a 'SegSeqDims' of @[0]@ and a
-- 'SegSpace' with dimensions @[n][m]@, there will be an outer loop
-- with @n@ iterations, while the @m@ dimension will be parallelised.
--
-- Semantically, this has no effect, but it may allow reductions in
-- memory usage or other low-level optimisations.  Operationally, the
-- guarantee is that for a 'SegSeqDims' of e.g. @[i,j,k]@, threads
-- running at any given moment will always have the same indexes along
-- the dimensions specified by @[i,j,k]@.
--
-- At the moment, this is only supported for 'SegNoVirtFull'
-- intra-group parallelism in GPU code, as we have not yet found it
-- useful anywhere else.
newtype SegSeqDims = SegSeqDims {segSeqDims :: [Int]}
  deriving (Eq, Ord, Show)

-- | Do we need group-virtualisation when generating code for the
-- segmented operation?  In most cases, we do, but for some simple
-- kernels, we compute the full number of groups in advance, and then
-- virtualisation is an unnecessary (but generally very small)
-- overhead.  This only really matters for fairly trivial but very
-- wide @map@ kernels where each thread performs constant-time work on
-- scalars.
data SegVirt
  = SegVirt
  | SegNoVirt
  | -- | Not only do we not need virtualisation, but we _guarantee_
    -- that all physical threads participate in the work.  This can
    -- save some checks in code generation.
    SegNoVirtFull SegSeqDims
  deriving (Eq, Ord, Show)

-- | Index space of a 'SegOp'.
data SegSpace = SegSpace
  { -- | Flat physical index corresponding to the dimensions (at
    -- code generation used for a thread ID or similar).
    segFlat :: VName,
    -- | The index variables of the space, each paired with the
    -- size it spans (see 'segSpaceDims').
    unSegSpace :: [(VName, SubExp)]
  }
  deriving (Eq, Ord, Show)

-- | The sizes spanned by the indexes of the 'SegSpace', in the order
-- the dimensions are listed.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims space = map snd (unSegSpace space)

-- | A 'Scope' containing all the identifiers brought into scope by
-- this 'SegSpace': the flat index and every per-dimension index, each
-- bound as an 'IndexName' of type 'Int64'.
scopeOfSegSpace :: SegSpace -> Scope rep
scopeOfSegSpace (SegSpace flat dims) =
  M.fromList [(v, IndexName Int64) | v <- flat : map fst dims]

-- | Require every dimension size of the 'SegSpace' to have type
-- @Prim int64@.
checkSegSpace :: TC.Checkable rep => SegSpace -> TC.TypeM rep ()
checkSegSpace space =
  mapM_ (TC.require [Prim int64]) (segSpaceDims space)

-- | A 'SegOp' is semantically a perfectly nested stack of maps, on
-- top of some bottommost computation (scalar computation, reduction,
-- scan, or histogram).  The 'SegSpace' encodes the original map
-- structure.
--
-- All 'SegOp's are parameterised by the representation of their body,
-- as well as a /level/.  The /level/ is a representation-specific bit
-- of information.  For example, in GPU backends, it is used to
-- indicate whether the 'SegOp' is expected to run at the thread-level
-- or the group-level.
data SegOp lvl rep
  = SegMap lvl SegSpace [Type] (KernelBody rep)
  | -- | The 'SegSpace' must always have at least two dimensions,
    -- implying that the result of a 'SegRed' is always an array.
    SegRed lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | SegScan lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | SegHist lvl SegSpace [HistOp rep] [Type] (KernelBody rep)
  deriving (Eq, Ord, Show)

-- | The level of a 'SegOp', i.e. the first field of whichever
-- constructor it was built with.
segLevel :: SegOp lvl rep -> lvl
segLevel segop =
  case segop of
    SegMap level _ _ _ -> level
    SegRed level _ _ _ _ -> level
    SegScan level _ _ _ _ -> level
    SegHist level _ _ _ _ -> level

-- | The index space of a 'SegOp'.
segSpace :: SegOp lvl rep -> SegSpace
segSpace segop =
  case segop of
    SegMap _ space _ _ -> space
    SegRed _ space _ _ _ -> space
    SegScan _ space _ _ _ -> space
    SegHist _ space _ _ _ -> space

-- | The kernel body of a 'SegOp', i.e. the last field of whichever
-- constructor it was built with.
segBody :: SegOp lvl rep -> KernelBody rep
segBody (SegMap _ _ _ kbody) = kbody
segBody (SegRed _ _ _ _ kbody) = kbody
segBody (SegScan _ _ _ _ kbody) = kbody
segBody (SegHist _ _ _ _ kbody) = kbody

-- | Given an element type and the way a kernel result is returned,
-- compute the type of the corresponding array:
--
-- * 'WriteReturns': the element type wrapped in the shape carried by
--   the result.
--
-- * 'Returns': the element type wrapped in every dimension of the
--   'SegSpace'.
--
-- * 'ConcatReturns': the element type with one outer dimension, whose
--   size is carried by the result.
--
-- * 'TileReturns' and 'RegTileReturns': the element type wrapped in
--   the first component of each entry of the dimension list.
segResultShape :: SegSpace -> Type -> KernelResult -> Type
segResultShape space t res =
  case res of
    WriteReturns _ shape _ _ ->
      arrayOfShape t shape
    Returns {} ->
      foldr (\dim acc -> arrayOfRow acc dim) t (segSpaceDims space)
    ConcatReturns _ _ w _ _ ->
      arrayOfRow t w
    TileReturns _ dims _ ->
      arrayOfShape t $ Shape $ map fst dims
    RegTileReturns _ dims_n_tiles _ ->
      arrayOfShape t $ Shape $ map (\(dim, _, _) -> dim) dims_n_tiles

-- | The return type of a 'SegOp'.
--
-- For 'SegMap', every declared type is shaped according to how the
-- matching kernel result is returned (see 'segResultShape').  For
-- 'SegRed' and 'SegScan', the types of the operator results come
-- first, followed by the shaped types of the remaining results.  For
-- 'SegHist', only the operator results are produced.
segOpType :: SegOp lvl rep -> [Type]
segOpType segop =
  case segop of
    SegMap _ space ts kbody ->
      mapTypes space ts (kernelBodyResult kbody)
    SegRed _ space reds ts kbody ->
      -- Reduction results keep all but the last dimension of the space.
      withMapTypes space ts kbody (binOpTypes (segmentDims space) reds)
    SegScan _ space scans ts kbody ->
      -- Scan results keep every dimension of the space.
      withMapTypes space ts kbody (binOpTypes (segSpaceDims space) scans)
    SegHist _ space ops _ _ ->
      let segment_dims = segmentDims space
       in [ t `arrayOfShape` (Shape segment_dims <> histShape op <> histOpShape op)
            | op <- ops,
              t <- lambdaReturnType (histOp op)
          ]
  where
    -- All dimensions of the space except the last one.  Only forced
    -- when an operator result type is actually demanded.
    segmentDims space = init (segSpaceDims space)

    -- Shape each element type by how its kernel result is returned.
    mapTypes space = zipWith (segResultShape space)

    -- The operator result types, followed by the types of the results
    -- that remain after skipping one type and one kernel result per
    -- operator result.
    withMapTypes space ts kbody op_ts =
      op_ts ++ mapTypes space (drop n ts) (drop n (kernelBodyResult kbody))
      where
        n = length op_ts

    -- For each operator, each of its lambda's return types wrapped in
    -- the given outer dimensions followed by the operator's own shape.
    binOpTypes dims ops =
      [ t `arrayOfShape` (Shape dims <> segBinOpShape op)
        | op <- ops,
          t <- lambdaReturnType (segBinOpLambda op)
      ]

-- | The type of a 'SegOp' is its 'segOpType', lifted to existential
-- types with 'staticShapes'.
instance TypedOp (SegOp lvl rep) where
  opType segop = pure $ staticShapes $ segOpType segop

instance
  (ASTRep rep, Aliased rep, ASTConstraints lvl) =>
  AliasedOp (SegOp lvl rep)
  where
  -- One empty alias set per result of the 'SegOp'.
  opAliases segop = mempty <$ segOpType segop

  -- Whatever the kernel body consumes; for 'SegHist', additionally
  -- the destination arrays of every histogram operation.
  consumedInOp segop =
    case segop of
      SegHist _ _ ops _ kbody ->
        namesFromList (concatMap histDest ops) <> consumedInKernelBody kbody
      _ ->
        consumedInKernelBody (segBody segop)

-- | Type check a 'SegOp', given a checker for its level.
typeCheckSegOp ::
  TC.Checkable rep =>
  (lvl -> TC.TypeM rep ()) ->
  SegOp lvl (Aliases rep) ->
  TC.TypeM rep ()
typeCheckSegOp :: (lvl -> TypeM rep ()) -> SegOp lvl (Aliases rep) -> TypeM rep ()
typeCheckSegOp lvl -> TypeM rep ()
checkLvl (SegMap lvl
lvl SegSpace
space [Type]
ts KernelBody (Aliases rep)
kbody) = do
  lvl -> TypeM rep ()
checkLvl lvl
lvl
  SegSpace
-> [(Lambda (Aliases rep), [SubExp], Shape)]
-> [Type]
-> KernelBody (Aliases rep)
-> TypeM rep ()
forall rep.
Checkable rep =>
SegSpace
-> [(Lambda (Aliases rep), [SubExp], Shape)]
-> [Type]
-> KernelBody (Aliases rep)
-> TypeM rep ()
checkScanRed SegSpace
space [] [Type]
ts KernelBody (Aliases rep)
kbody
typeCheckSegOp lvl -> TypeM rep ()
checkLvl (SegRed lvl
lvl SegSpace
space [SegBinOp (Aliases rep)]
reds [Type]
ts KernelBody (Aliases rep)
body) = do
  lvl -> TypeM rep ()
checkLvl lvl
lvl
  SegSpace
-> [(Lambda (Aliases rep), [SubExp], Shape)]
-> [Type]
-> KernelBody (Aliases rep)
-> TypeM rep ()
forall rep.
Checkable rep =>
SegSpace
-> [(Lambda (Aliases rep), [SubExp], Shape)]
-> [Type]
-> KernelBody (Aliases rep)
-> TypeM rep ()
checkScanRed SegSpace
space [(Lambda (Aliases rep), [SubExp], Shape)]
reds' [Type]
ts KernelBody (Aliases rep)
body
  where
    reds' :: [(Lambda (Aliases rep), [SubExp], Shape)]
reds' =
      [Lambda (Aliases rep)]
-> [[SubExp]]
-> [Shape]
-> [(Lambda (Aliases rep), [SubExp], Shape)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3
        ((SegBinOp (Aliases rep) -> Lambda (Aliases rep))
-> [SegBinOp (Aliases rep)] -> [Lambda (Aliases rep)]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp (Aliases rep) -> Lambda (Aliases rep)
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda [SegBinOp (Aliases rep)]
reds)
        ((SegBinOp (Aliases rep) -> [SubExp])
-> [SegBinOp (Aliases rep)] -> [[SubExp]]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp (Aliases rep) -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral [SegBinOp (Aliases rep)]
reds)
        ((SegBinOp (Aliases rep) -> Shape)
-> [SegBinOp (Aliases rep)] -> [Shape]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp (Aliases rep) -> Shape
forall rep. SegBinOp rep -> Shape
segBinOpShape [SegBinOp (Aliases rep)]
reds)
typeCheckSegOp lvl -> TypeM rep ()
checkLvl (SegScan lvl
lvl SegSpace
space [SegBinOp (Aliases rep)]
scans [Type]
ts KernelBody (Aliases rep)
body) = do
  lvl -> TypeM rep ()
checkLvl lvl
lvl
  SegSpace
-> [(Lambda (Aliases rep), [SubExp], Shape)]
-> [Type]
-> KernelBody (Aliases rep)
-> TypeM rep ()
forall rep.
Checkable rep =>
SegSpace
-> [(Lambda (Aliases rep), [SubExp], Shape)]
-> [Type]
-> KernelBody (Aliases rep)
-> TypeM rep ()
checkScanRed SegSpace
space [(Lambda (Aliases rep), [SubExp], Shape)]
scans' [Type]
ts KernelBody (Aliases rep)
body
  where
    scans' :: [(Lambda (Aliases rep), [SubExp], Shape)]
scans' =
      [Lambda (Aliases rep)]
-> [[SubExp]]
-> [Shape]
-> [(Lambda (Aliases rep), [SubExp], Shape)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3
        ((SegBinOp (Aliases rep) -> Lambda (Aliases rep))
-> [SegBinOp (Aliases rep)] -> [Lambda (Aliases rep)]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp (Aliases rep) -> Lambda (Aliases rep)
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda [SegBinOp (Aliases rep)]
scans)
        ((SegBinOp (Aliases rep) -> [SubExp])
-> [SegBinOp (Aliases rep)] -> [[SubExp]]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp (Aliases rep) -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral [SegBinOp (Aliases rep)]
scans)
        ((SegBinOp (Aliases rep) -> Shape)
-> [SegBinOp (Aliases rep)] -> [Shape]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp (Aliases rep) -> Shape
forall rep. SegBinOp rep -> Shape
segBinOpShape [SegBinOp (Aliases rep)]
scans)
typeCheckSegOp lvl -> TypeM rep ()
checkLvl (SegHist lvl
lvl SegSpace
space [HistOp (Aliases rep)]
ops [Type]
ts KernelBody (Aliases rep)
kbody) = do
  lvl -> TypeM rep ()
checkLvl lvl
lvl
  SegSpace -> TypeM rep ()
forall rep. Checkable rep => SegSpace -> TypeM rep ()
checkSegSpace SegSpace
space
  (Type -> TypeM rep ()) -> [Type] -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Type -> TypeM rep ()
forall rep u. Checkable rep => TypeBase Shape u -> TypeM rep ()
TC.checkType [Type]
ts

  Scope (Aliases rep) -> TypeM rep () -> TypeM rep ()
forall rep a.
Checkable rep =>
Scope (Aliases rep) -> TypeM rep a -> TypeM rep a
TC.binding (SegSpace -> Scope (Aliases rep)
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ do
    [[Type]]
nes_ts <- [HistOp (Aliases rep)]
-> (HistOp (Aliases rep) -> TypeM rep [Type]) -> TypeM rep [[Type]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp (Aliases rep)]
ops ((HistOp (Aliases rep) -> TypeM rep [Type]) -> TypeM rep [[Type]])
-> (HistOp (Aliases rep) -> TypeM rep [Type]) -> TypeM rep [[Type]]
forall a b. (a -> b) -> a -> b
$ \(HistOp Shape
dest_shape SubExp
rf [VName]
dests [SubExp]
nes Shape
shape Lambda (Aliases rep)
op) -> do
      (SubExp -> TypeM rep ()) -> Shape -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) Shape
dest_shape
      [Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
rf
      [Arg]
nes' <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
nes
      (SubExp -> TypeM rep ()) -> [SubExp] -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) ([SubExp] -> TypeM rep ()) -> [SubExp] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
shape

      -- Operator type must match the type of neutral elements.
      let stripVecDims :: Type -> Type
stripVecDims = Int -> Type -> Type
forall u. Int -> TypeBase Shape u -> TypeBase Shape u
stripArray (Int -> Type -> Type) -> Int -> Type -> Type
forall a b. (a -> b) -> a -> b
$ Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank Shape
shape
      Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
op ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map (Arg -> Arg
TC.noArgAliases (Arg -> Arg) -> (Arg -> Arg) -> Arg -> Arg
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. (Type -> Type) -> Arg -> Arg
forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first Type -> Type
stripVecDims) ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
nes'
      let nes_t :: [Type]
nes_t = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
nes'
      Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
nes_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
op) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
        ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
          String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError (String -> ErrorCase rep) -> String -> ErrorCase rep
forall a b. (a -> b) -> a -> b
$
            String
"SegHist operator has return type "
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
op)
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but neutral element has type "
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
nes_t

      -- Arrays must have proper type.
      let dest_shape' :: Shape
dest_shape' = [SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
segment_dims Shape -> Shape -> Shape
forall a. Semigroup a => a -> a -> a
<> Shape
dest_shape Shape -> Shape -> Shape
forall a. Semigroup a => a -> a -> a
<> Shape
shape
      [(Type, VName)] -> ((Type, VName) -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Type] -> [VName] -> [(Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Type]
nes_t [VName]
dests) (((Type, VName) -> TypeM rep ()) -> TypeM rep ())
-> ((Type, VName) -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \(Type
t, VName
dest) -> do
        [Type] -> VName -> TypeM rep ()
forall rep. Checkable rep => [Type] -> VName -> TypeM rep ()
TC.requireI [Type
t Type -> Shape -> Type
`arrayOfShape` Shape
dest_shape'] VName
dest
        Names -> TypeM rep ()
forall rep. Checkable rep => Names -> TypeM rep ()
TC.consume (Names -> TypeM rep ()) -> TypeM rep Names -> TypeM rep ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> TypeM rep Names
forall rep. Checkable rep => VName -> TypeM rep Names
TC.lookupAliases VName
dest

      [Type] -> TypeM rep [Type]
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([Type] -> TypeM rep [Type]) -> [Type] -> TypeM rep [Type]
forall a b. (a -> b) -> a -> b
$ (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> Shape -> Type
`arrayOfShape` Shape
shape) [Type]
nes_t

    [Type] -> KernelBody (Aliases rep) -> TypeM rep ()
forall rep.
Checkable rep =>
[Type] -> KernelBody (Aliases rep) -> TypeM rep ()
checkKernelBody [Type]
ts KernelBody (Aliases rep)
kbody

    -- Return type of bucket function must be an index for each
    -- operation followed by the values to write.
    let bucket_ret_t :: [Type]
bucket_ret_t =
          (HistOp (Aliases rep) -> [Type])
-> [HistOp (Aliases rep)] -> [Type]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap ((Int -> Type -> [Type]
forall a. Int -> a -> [a]
`replicate` PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) (Int -> [Type])
-> (HistOp (Aliases rep) -> Int) -> HistOp (Aliases rep) -> [Type]
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank (Shape -> Int)
-> (HistOp (Aliases rep) -> Shape) -> HistOp (Aliases rep) -> Int
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. HistOp (Aliases rep) -> Shape
forall rep. HistOp rep -> Shape
histShape) [HistOp (Aliases rep)]
ops
            [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [[Type]] -> [Type]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[Type]]
nes_ts
    Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
bucket_ret_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== [Type]
ts) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
      ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
        String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError (String -> ErrorCase rep) -> String -> ErrorCase rep
forall a b. (a -> b) -> a -> b
$
          String
"SegHist body has return type "
            String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
ts
            String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but should have type "
            String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
bucket_ret_t
  where
    segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. [a] -> [a]
init ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space

-- | Type-check a segmented operation described by a list of binary
-- operators.  Each element of @ops@ is an operator lambda, its
-- neutral elements, and a shape.
--
-- The segment space and every type in @ts@ are checked first.  Then,
-- with the names of the segment space in scope, every operator is
-- checked in turn, the leading types of @ts@ are compared against the
-- types implied by the operators, and finally the kernel body is
-- checked against all of @ts@.
checkScanRed ::
  TC.Checkable rep =>
  SegSpace ->
  [(Lambda (Aliases rep), [SubExp], Shape)] ->
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkScanRed space ops ts kbody = do
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    ne_ts <- forM ops $ \(lam, nes, shape) -> do
      -- Every dimension of the operator shape must be an int64.
      mapM_ (TC.require [Prim int64]) $ shapeDims shape
      nes' <- mapM TC.checkArg nes

      -- Operator type must match the type of neutral elements.  The
      -- lambda is checked against the neutral elements repeated twice
      -- (one copy per operand), after passing them through
      -- 'TC.noArgAliases'.
      TC.checkLambda lam $ map TC.noArgAliases $ nes' ++ nes'
      let nes_t = map TC.argType nes'

      unless (lambdaReturnType lam == nes_t) $
        TC.bad $
          TC.TypeError "wrong type for operator or neutral elements."

      -- The types this operator contributes to the body's return
      -- type: the neutral element types, turned into arrays of the
      -- operator shape.
      pure $ map (`arrayOfShape` shape) nes_t

    -- Only the first @length expecting@ types of @ts@ are compared
    -- against the operators; any further types are not constrained
    -- here.
    let expecting = concat ne_ts
        got = take (length expecting) ts
    unless (expecting == got) $
      TC.bad $
        TC.TypeError $
          "Wrong return for body (does not match neutral elements; expected "
            ++ pretty expecting
            ++ "; found "
            ++ pretty got
            ++ ")"

    checkKernelBody ts kbody

-- | Like 'Mapper', but just for 'SegOp's.  A value of this type is a
-- bundle of monadic transformations, one for each kind of component
-- that 'mapSegOpM' visits.  The representation may change from @frep@
-- to @trep@ through the lambda and body transformations.
data SegOpMapper lvl frep trep m = SegOpMapper
  { -- | Applied to 'SubExp's.
    mapOnSegOpSubExp :: SubExp -> m SubExp,
    -- | Applied to operator lambdas.
    mapOnSegOpLambda :: Lambda frep -> m (Lambda trep),
    -- | Applied to the kernel body.
    mapOnSegOpBody :: KernelBody frep -> m (KernelBody trep),
    -- | Applied to variable names.
    mapOnSegOpVName :: VName -> m VName,
    -- | Applied to the level.
    mapOnSegOpLevel :: lvl -> m lvl
  }

-- | A mapper that simply returns the 'SegOp' verbatim.  Every field
-- is 'pure', so it is a convenient starting point for record-update
-- syntax when only some components need to be transformed.
identitySegOpMapper :: Monad m => SegOpMapper lvl rep rep m
identitySegOpMapper =
  SegOpMapper
    { mapOnSegOpSubExp = pure,
      mapOnSegOpLambda = pure,
      mapOnSegOpBody = pure,
      mapOnSegOpVName = pure,
      mapOnSegOpLevel = pure
    }

-- | Apply a 'SegOpMapper' to a 'SegSpace'.  The first field of the
-- space and the name of every dimension go through 'mapOnSegOpVName';
-- the size of every dimension goes through 'mapOnSegOpSubExp'.
mapOnSegSpace ::
  Monad f => SegOpMapper lvl frep trep f -> SegSpace -> f SegSpace
mapOnSegSpace tv (SegSpace phys dims) =
  SegSpace
    <$> mapOnSegOpVName tv phys
    <*> traverse (bitraverse (mapOnSegOpVName tv) (mapOnSegOpSubExp tv)) dims

-- | Apply a 'SegOpMapper' to a 'SegBinOp'.  The commutativity is kept
-- as is; the operator goes through 'mapOnSegOpLambda', and both the
-- neutral elements and the dimensions of the shape go through
-- 'mapOnSegOpSubExp'.
mapSegBinOp ::
  Monad m =>
  SegOpMapper lvl frep trep m ->
  SegBinOp frep ->
  m (SegBinOp trep)
mapSegBinOp tv (SegBinOp comm red_op nes shape) =
  SegBinOp comm
    <$> mapOnSegOpLambda tv red_op
    <*> mapM (mapOnSegOpSubExp tv) nes
    <*> (Shape <$> mapM (mapOnSegOpSubExp tv) (shapeDims shape))

-- | Apply a 'SegOpMapper' to the given 'SegOp'.
--
-- For every constructor the level, the segment space, the operators
-- (if any), the return types and the kernel body are transformed, in
-- that order.
--
-- Note that the return types of a 'SegMap' are transformed with
-- 'mapOnSegOpType', while those of 'SegRed', 'SegScan' and 'SegHist'
-- are transformed with 'mapOnType' applied to 'mapOnSegOpSubExp'.
-- NOTE(review): confirm that this difference is intentional.
mapSegOpM ::
  (Applicative m, Monad m) =>
  SegOpMapper lvl frep trep m ->
  SegOp lvl frep ->
  m (SegOp lvl trep)
mapSegOpM tv (SegMap lvl space ts body) =
  SegMap
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapOnSegOpType tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegRed lvl space reds ts lam) =
  SegRed
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapSegBinOp tv) reds
    <*> mapM (mapOnType $ mapOnSegOpSubExp tv) ts
    <*> mapOnSegOpBody tv lam
mapSegOpM tv (SegScan lvl space scans ts body) =
  SegScan
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM (mapSegBinOp tv) scans
    <*> mapM (mapOnType $ mapOnSegOpSubExp tv) ts
    <*> mapOnSegOpBody tv body
mapSegOpM tv (SegHist lvl space ops ts body) =
  SegHist
    <$> mapOnSegOpLevel tv lvl
    <*> mapOnSegSpace tv space
    <*> mapM onHistOp ops
    <*> mapM (mapOnType $ mapOnSegOpSubExp tv) ts
    <*> mapOnSegOpBody tv body
  where
    -- Transform a single histogram operation: the two shapes, the
    -- second field and the neutral elements go through
    -- 'mapOnSegOpSubExp', the arrays through 'mapOnSegOpVName', and
    -- the operator through 'mapOnSegOpLambda'.
    onHistOp (HistOp w rf arrs nes shape op) =
      HistOp
        <$> mapM (mapOnSegOpSubExp tv) w
        <*> mapOnSegOpSubExp tv rf
        <*> mapM (mapOnSegOpVName tv) arrs
        <*> mapM (mapOnSegOpSubExp tv) nes
        <*> (Shape <$> mapM (mapOnSegOpSubExp tv) (shapeDims shape))
        <*> mapOnSegOpLambda tv op

-- | Apply a 'SegOpMapper' to the names and sizes occurring in a type.
--
-- Primitive and memory types are returned unchanged.  For an
-- accumulator type, the accumulator name goes through
-- 'mapOnSegOpVName', and the index space as well as the shapes of the
-- element types go through 'mapOnSegOpSubExp'.  For an array type,
-- the shape goes through 'mapOnSegOpSubExp'.
mapOnSegOpType ::
  Monad m =>
  SegOpMapper lvl frep trep m ->
  Type ->
  m Type
mapOnSegOpType _tv t@Prim {} = pure t
mapOnSegOpType tv (Acc acc ispace ts u) =
  Acc
    <$> mapOnSegOpVName tv acc
    <*> traverse (mapOnSegOpSubExp tv) ispace
    -- Only the shape component of each element type is transformed;
    -- the uniqueness component is kept as is.
    <*> traverse (bitraverse (traverse (mapOnSegOpSubExp tv)) pure) ts
    <*> pure u
mapOnSegOpType tv (Array et shape u) =
  Array et <$> traverse (mapOnSegOpSubExp tv) shape <*> pure u
mapOnSegOpType _tv (Mem s) = pure $ Mem s

-- | A helper for defining 'TraverseOpStms'.
--
-- The supplied traverser is applied to the statements of the kernel
-- body and, via 'traverseLambdaStms', to the lambdas of the 'SegOp'.
-- In both cases the scope handed to the traverser is extended with
-- the names bound by the segment space ('scopeOfSegSpace').  The
-- kernel results are passed through unchanged.
traverseSegOpStms :: Monad m => OpStmsTraverser m (SegOp lvl rep) rep
traverseSegOpStms f segop = mapSegOpM mapper segop
  where
    -- Names brought into scope by the segment space of this op.
    seg_scope = scopeOfSegSpace (segSpace segop)
    -- The traverser used for lambdas: whatever scope it is given is
    -- combined with the segment space scope first.
    f' scope = f (seg_scope <> scope)
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = traverseLambdaStms f',
          mapOnSegOpBody = onBody
        }
    -- Only the statements are traversed; the body decoration and the
    -- results are kept as they are.
    onBody (KernelBody dec stms res) =
      KernelBody dec <$> f seg_scope stms <*> pure res

-- | Name substitution on a 'SegOp' is performed component-wise: the
-- same substitution map is applied to every sub-expression, lambda,
-- kernel body, variable name and level visited by 'mapSegOpM'.
instance
  (ASTRep rep, Substitute lvl) =>
  Substitute (SegOp lvl rep)
  where
  substituteNames subst = runIdentity . mapSegOpM substitute
    where
      -- A pure mapper (run in 'Identity') that substitutes in each
      -- kind of component.
      substitute =
        SegOpMapper
          { mapOnSegOpSubExp = pure . substituteNames subst,
            mapOnSegOpLambda = pure . substituteNames subst,
            mapOnSegOpBody = pure . substituteNames subst,
            mapOnSegOpVName = pure . substituteNames subst,
            mapOnSegOpLevel = pure . substituteNames subst
          }

-- | Renaming a 'SegOp': the names in the scope of its segment space
-- are passed to 'renameBound', and inside that the components of the
-- op are renamed with their own 'Rename' instances.
instance (ASTRep rep, ASTConstraints lvl) => Rename (SegOp lvl rep) where
  rename op =
    renameBound (M.keys (scopeOfSegSpace (segSpace op))) $ mapSegOpM renamer op
    where
      -- Positional fields: sub-expressions, lambdas, kernel bodies,
      -- variable names, and the level.
      renamer = SegOpMapper rename rename rename rename rename

-- | The free names of a 'SegOp' are gathered by visiting every
-- component with 'mapSegOpM' and accumulating the result of 'freeIn''
-- in a state monad.  The names in the scope of the segment space are
-- handed to 'fvBind', so they are treated as bound by the op.
instance
  (ASTRep rep, FreeIn (LParamInfo rep), FreeIn lvl) =>
  FreeIn (SegOp lvl rep)
  where
  freeIn' e =
    fvBind (namesFromList $ M.keys $ scopeOfSegSpace (segSpace e)) $
      flip execState mempty $
        mapSegOpM free e
    where
      -- Record the contribution of @x@ in the state and return @x@
      -- itself, so the traversal rebuilds the op unchanged.
      walk f x = modify (<> f x) >> pure x
      free =
        SegOpMapper
          { mapOnSegOpSubExp = walk freeIn',
            mapOnSegOpLambda = walk freeIn',
            mapOnSegOpBody = walk freeIn',
            mapOnSegOpVName = walk freeIn',
            mapOnSegOpLevel = walk freeIn'
          }

-- | Metrics for a 'SegOp' are recorded 'inside' a section named after
-- the constructor.  The kernel body is always counted; for 'SegRed',
-- 'SegScan' and 'SegHist' the operator lambdas are counted first.
instance OpMetrics (Op rep) => OpMetrics (SegOp lvl rep) where
  opMetrics (SegMap _ _ _ body) =
    inside "SegMap" $ kernelBodyMetrics body
  opMetrics (SegRed _ _ reds _ body) =
    inside "SegRed" $ do
      mapM_ (lambdaMetrics . segBinOpLambda) reds
      kernelBodyMetrics body
  opMetrics (SegScan _ _ scans _ body) =
    inside "SegScan" $ do
      mapM_ (lambdaMetrics . segBinOpLambda) scans
      kernelBodyMetrics body
  opMetrics (SegHist _ _ ops _ body) =
    inside "SegHist" $ do
      mapM_ (lambdaMetrics . histOp) ops
      kernelBodyMetrics body

-- | A segment space is printed as a parenthesised, comma-separated
-- list of @index < bound@ pairs, followed by the flat index in
-- parentheses prefixed with a tilde.
instance Pretty SegSpace where
  ppr (SegSpace phys dims) =
    parens
      ( commasep $ do
          (i, d) <- dims
          pure $ ppr i <+> "<" <+> ppr d
      )
      <+> parens (text "~" <> ppr phys)

-- | A 'SegBinOp' is printed as its neutral elements in braces, then
-- its shape, then the lambda; the lambda is prefixed with
-- @commutative @ when the operator is marked 'Commutative'.
instance PrettyRep rep => Pretty (SegBinOp rep) where
  ppr (SegBinOp comm lam nes shape) =
    PP.braces (PP.commasep $ map ppr nes) <> PP.comma
      </> ppr shape <> PP.comma
      </> comm' <> ppr lam
    where
      -- Nothing is printed for non-commutative operators.
      comm' = case comm of
        Commutative -> text "commutative "
        Noncommutative -> mempty

-- | Every 'SegOp' is printed as a keyword (@segmap@, @segred@,
-- @segscan@ or @seghist@) immediately followed by the level, then the
-- segment space, then (except for @segmap@) the parenthesised list of
-- operators, then a colon, the result types, and the kernel body in a
-- brace-delimited nested block.
instance (PrettyRep rep, PP.Pretty lvl) => PP.Pretty (SegOp lvl rep) where
  ppr (SegMap lvl space ts body) =
    text "segmap" <> ppr lvl
      </> PP.align (ppr space)
      <+> PP.colon
      <+> ppTuple' ts
      <+> PP.nestedBlock "{" "}" (ppr body)
  ppr (SegRed lvl space reds ts body) =
    text "segred" <> ppr lvl
      </> PP.align (ppr space)
      </> PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map ppr reds)
      </> PP.colon
      <+> ppTuple' ts
      <+> PP.nestedBlock "{" "}" (ppr body)
  ppr (SegScan lvl space scans ts body) =
    text "segscan" <> ppr lvl
      </> PP.align (ppr space)
      </> PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map ppr scans)
      </> PP.colon
      <+> ppTuple' ts
      <+> PP.nestedBlock "{" "}" (ppr body)
  ppr (SegHist lvl space ops ts body) =
    text "seghist" <> ppr lvl
      </> PP.align (ppr space)
      </> PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map ppOp ops)
      </> PP.colon
      <+> ppTuple' ts
      <+> PP.nestedBlock "{" "}" (ppr body)
    where
      -- A histogram operator is printed field by field, in
      -- constructor order except that the lambda comes last.
      ppOp (HistOp w rf dests nes shape op) =
        ppr w <> PP.comma
          <+> ppr rf <> PP.comma
          </> PP.braces (PP.commasep $ map ppr dests) <> PP.comma
          </> PP.braces (PP.commasep $ map ppr nes) <> PP.comma
          </> ppr shape <> PP.comma
          </> ppr op

-- | A 'SegOp' is given alias information by analysing its lambdas and
-- kernel body; sub-expressions, variable names and the level are left
-- untouched.  Removing the information is the mirror image.
instance
  ( ASTRep rep,
    ASTRep (Aliases rep),
    CanBeAliased (Op rep),
    ASTConstraints lvl
  ) =>
  CanBeAliased (SegOp lvl rep)
  where
  type OpWithAliases (SegOp lvl rep) = SegOp lvl (Aliases rep)

  addOpAliases aliases = runIdentity . mapSegOpM alias
    where
      -- Positional fields: sub-expressions, lambdas, kernel bodies,
      -- variable names, and the level.
      alias =
        SegOpMapper
          pure
          (pure . Alias.analyseLambda aliases)
          (pure . aliasAnalyseKernelBody aliases)
          pure
          pure

  removeOpAliases = runIdentity . mapSegOpM remove
    where
      -- Same field order as in 'addOpAliases'.
      remove =
        SegOpMapper
          pure
          (pure . removeLambdaAliases)
          (pure . removeKernelBodyAliases)
          pure
          pure

-- Convert a kernel body to its 'Wise' counterpart: the statements are
-- converted with 'informStms', and the body is then rebuilt with
-- 'mkWiseKernelBody' from the original decoration and results.
informKernelBody :: Informing rep => KernelBody rep -> KernelBody (Wise rep)
informKernelBody (KernelBody dec stms res) =
  mkWiseKernelBody dec (informStms stms) res

instance
  (CanBeWise (Op rep), ASTRep rep, ASTConstraints lvl) =>
  CanBeWise (SegOp lvl rep)
  where
  type OpWithWisdom (SegOp lvl rep) = SegOp lvl (Wise rep)

  removeOpWisdom :: OpWithWisdom (SegOp lvl rep) -> SegOp lvl rep
removeOpWisdom = Identity (SegOp lvl rep) -> SegOp lvl rep
forall a. Identity a -> a
runIdentity (Identity (SegOp lvl rep) -> SegOp lvl rep)
-> (SegOp lvl (Wise rep) -> Identity (SegOp lvl rep))
-> SegOp lvl (Wise rep)
-> SegOp lvl rep
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. SegOpMapper lvl (Wise rep) rep Identity
-> SegOp lvl (Wise rep) -> Identity (SegOp lvl rep)
forall (m :: * -> *) lvl frep trep.
(Applicative m, Monad m) =>
SegOpMapper lvl frep trep m -> SegOp lvl frep -> m (SegOp lvl trep)
mapSegOpM SegOpMapper lvl (Wise rep) rep Identity
forall lvl. SegOpMapper lvl (Wise rep) rep Identity
remove
    where
      remove :: SegOpMapper lvl (Wise rep) rep Identity
remove =
        (SubExp -> Identity SubExp)
-> (Lambda (Wise rep) -> Identity (Lambda rep))
-> (KernelBody (Wise rep) -> Identity (KernelBody rep))
-> (VName -> Identity VName)
-> (lvl -> Identity lvl)
-> SegOpMapper lvl (Wise rep) rep Identity
forall lvl frep trep (m :: * -> *).
(SubExp -> m SubExp)
-> (Lambda frep -> m (Lambda trep))
-> (KernelBody frep -> m (KernelBody trep))
-> (VName -> m VName)
-> (lvl -> m lvl)
-> SegOpMapper lvl frep trep m
SegOpMapper
          SubExp -> Identity SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure
          (Lambda rep -> Identity (Lambda rep)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda rep -> Identity (Lambda rep))
-> (Lambda (Wise rep) -> Lambda rep)
-> Lambda (Wise rep)
-> Identity (Lambda rep)
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Lambda (Wise rep) -> Lambda rep
forall rep. CanBeWise (Op rep) => Lambda (Wise rep) -> Lambda rep
removeLambdaWisdom)
          (KernelBody rep -> Identity (KernelBody rep)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (KernelBody rep -> Identity (KernelBody rep))
-> (KernelBody (Wise rep) -> KernelBody rep)
-> KernelBody (Wise rep)
-> Identity (KernelBody rep)
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. KernelBody (Wise rep) -> KernelBody rep
forall rep.
CanBeWise (Op rep) =>
KernelBody (Wise rep) -> KernelBody rep
removeKernelBodyWisdom)
          VName -> Identity VName
forall (f :: * -> *) a. Applicative f => a -> f a
pure
          lvl -> Identity lvl
forall (f :: * -> *) a. Applicative f => a -> f a
pure

  -- Convert a 'SegOp' over @rep@ into one over @Wise rep@.  Lambdas
  -- are converted with 'informLambda' and kernel bodies with
  -- 'informKernelBody'; sub-expressions, names and the level are
  -- passed through unchanged.
  addOpWisdom = runIdentity . mapSegOpM add
    where
      add =
        SegOpMapper
          pure
          (pure . informLambda)
          (pure . informKernelBody)
          pure
          pure

-- | Symbolic indexing into the results of a 'SegMap'.
--
-- For result number @k@, indexed by @is@, this succeeds only when the
-- result is a 'Returns' with manifest 'ResultMaySimplify' and at
-- least as many indexes are given as the segment space has thread
-- indexes.  Each thread index is bound to the corresponding index
-- expression, the body statements are then walked in order to extend
-- that table, and finally the returned variable is looked up in it.
-- Every other kind of 'SegOp' yields 'Nothing'.
instance ASTRep rep => ST.IndexOp (SegOp lvl rep) where
  indexOp vtable k (SegMap _ space _ kbody) is = do
    Returns ResultMaySimplify _ se <- maybeNth k $ kernelBodyResult kbody
    guard $ length gtids <= length is
    let idx_table = M.fromList $ zip gtids $ map (ST.Indexed mempty . untyped) is
        idx_table' = foldl' expandIndexedTable idx_table $ kernelBodyStms kbody
    case se of
      Var v -> M.lookup v idx_table'
      _ -> Nothing
    where
      (gtids, _) = unzip $ unSegSpace space
      -- Indexes in excess of what is used to index through the
      -- segment dimensions.
      excess_is = drop (length gtids) is

      -- Extend the table with the single name bound by a statement,
      -- when the statement is either expressible as a 'PrimExp' over
      -- the table, or is an 'Index' into an array available in
      -- @vtable@ whose slice has exactly as many dimensions as there
      -- are excess indexes.  Other statements leave the table as is.
      expandIndexedTable table stm
        | [v] <- patNames $ stmPat stm,
          Just (pe, cs) <-
            runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm =
            M.insert v (ST.Indexed (stmCerts stm <> cs) pe) table
        | [v] <- patNames $ stmPat stm,
          BasicOp (Index arr slice) <- stmExp stm,
          length (sliceDims slice) == length excess_is,
          arr `ST.available` vtable,
          Just (slice', cs) <- asPrimExpSlice table slice =
            let idx =
                  ST.IndexedArray
                    (stmCerts stm <> cs)
                    arr
                    (fixSlice (fmap isInt64 slice') excess_is)
             in M.insert v idx table
        | otherwise =
            table

      -- Convert every component of a slice to a 'PrimExp', collecting
      -- the certificates of the table entries that were used.
      asPrimExpSlice table =
        runWriterT . traverse (primExpFromSubExpM (asPrimExp table))

      -- Resolve a variable: prefer a scalar entry in the table
      -- (emitting its certificates), otherwise fall back to a leaf
      -- if @vtable@ knows the variable has primitive type, otherwise
      -- fail.
      asPrimExp table v
        | Just (ST.Indexed cs e) <- M.lookup v table = tell cs >> pure e
        | Just (Prim pt) <- ST.lookupType v vtable =
            pure $ LeafExp v pt
        | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing

-- | No 'SegOp' is reported as cheap, and every 'SegOp' is reported as
-- safe.
instance
  (ASTRep rep, ASTConstraints lvl) =>
  IsOp (SegOp lvl rep)
  where
  cheapOp _ = False
  safeOp _ = True

--- Simplification

-- | 'SplitContiguous' has nothing to simplify; for 'SplitStrided'
-- the stride sub-expression is simplified.
instance Engine.Simplifiable SplitOrdering where
  simplify SplitContiguous =
    pure SplitContiguous
  simplify (SplitStrided stride) =
    SplitStrided <$> Engine.simplify stride

-- | Simplify the size of each dimension; the first field and the
-- per-dimension names are kept as they are.
instance Engine.Simplifiable SegSpace where
  simplify (SegSpace phys dims) =
    -- 'traverse' on a pair only visits its second component, i.e. the
    -- dimension size.
    SegSpace phys <$> mapM (traverse Engine.simplify) dims

-- | Simplify every field of a kernel result, constructor by
-- constructor.  The only field left untouched is the manifest of
-- 'Returns'.
instance Engine.Simplifiable KernelResult where
  simplify (Returns manifest cs what) =
    Returns manifest <$> Engine.simplify cs <*> Engine.simplify what
  simplify (WriteReturns cs ws a res) =
    WriteReturns
      <$> Engine.simplify cs
      <*> Engine.simplify ws
      <*> Engine.simplify a
      <*> Engine.simplify res
  simplify (ConcatReturns cs o w pte what) =
    ConcatReturns
      <$> Engine.simplify cs
      <*> Engine.simplify o
      <*> Engine.simplify w
      <*> Engine.simplify pte
      <*> Engine.simplify what
  simplify (TileReturns cs dims what) =
    TileReturns
      <$> Engine.simplify cs
      <*> Engine.simplify dims
      <*> Engine.simplify what
  simplify (RegTileReturns cs dims_n_tiles what) =
    RegTileReturns
      <$> Engine.simplify cs
      <*> Engine.simplify dims_n_tiles
      <*> Engine.simplify what

-- | Build a 'KernelBody' in the @Wise@ representation from a body
-- decoration, statements and kernel results.
--
-- The decoration is obtained by constructing an ordinary body with
-- 'mkWiseBody' whose result is the sub-expression of each kernel
-- result, and reusing the decoration of that body.
mkWiseKernelBody ::
  (ASTRep rep, CanBeWise (Op rep)) =>
  BodyDec rep ->
  Stms (Wise rep) ->
  [KernelResult] ->
  KernelBody (Wise rep)
mkWiseKernelBody dec stms res =
  let Body dec' _ _ = mkWiseBody dec stms $ subExpsRes res_vs
   in KernelBody dec' stms res
  where
    res_vs = map kernelResultSubExp res

-- | Build a 'KernelBody' from statements and kernel results inside a
-- 'MonadBuilder'.
--
-- The body decoration is taken from an ordinary body constructed with
-- 'mkBodyM' over the same statements, using the sub-expression of
-- each kernel result as the body result.
mkKernelBodyM ::
  MonadBuilder m =>
  Stms (Rep m) ->
  [KernelResult] ->
  m (KernelBody (Rep m))
mkKernelBodyM stms kres = do
  Body dec' _ _ <- mkBodyM stms $ subExpsRes res_ses
  pure $ KernelBody dec' stms kres
  where
    res_ses = map kernelResultSubExp kres

-- | Simplify the body of a kernel running over the given segment
-- space.
--
-- Returns the rebuilt kernel body together with the second group of
-- statements produced by 'Engine.blockIf' (bound here as @hoisted@).
-- The incoming body decoration is discarded and recomputed by
-- 'mkWiseKernelBody'.
simplifyKernelBody ::
  (Engine.SimplifiableRep rep, BodyDec rep ~ ()) =>
  SegSpace ->
  KernelBody (Wise rep) ->
  Engine.SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
simplifyKernelBody space (KernelBody _ stms res) = do
  par_blocker <- Engine.asksEngineEnv $ Engine.blockHoistPar . Engine.envHoistBlockers

  -- A statement is blocked if any of these predicates holds; the
  -- first one covers statements that mention a name bound by the
  -- segment space itself.
  let blocker =
        Engine.hasFree bound_here
          `Engine.orIf` Engine.isOp
          `Engine.orIf` par_blocker
          `Engine.orIf` Engine.isConsumed
          `Engine.orIf` Engine.isConsuming
          `Engine.orIf` Engine.isDeviceMigrated

  -- Ensure we do not try to use anything that is consumed in the result.
  (body_res, body_stms, hoisted) <-
    Engine.localVtable (flip (foldl' (flip ST.consume)) (foldMap consumedInResult res))
      . Engine.localVtable (<> scope_vtable)
      . Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True})
      . Engine.enterLoop
      $ Engine.blockIf blocker stms
      $ do
        -- The results are simplified with the names bound by the body
        -- statements passed to 'ST.hideCertified'.
        res' <-
          Engine.localVtable (ST.hideCertified $ namesFromList $ M.keys $ scopeOf stms) $
            mapM Engine.simplify res
        pure (res', UT.usages $ freeIn res')

  pure (mkWiseKernelBody () body_stms body_res, hoisted)
  where
    scope_vtable = segSpaceSymbolTable space
    bound_here = namesFromList $ M.keys $ scopeOfSegSpace space

    -- Only 'WriteReturns' contributes a consumed name: its array.
    consumedInResult (WriteReturns _ _ arr _) =
      [arr]
    consumedInResult _ =
      []

-- | Simplify a lambda by running 'Engine.simplifyLambda' underneath
-- 'Engine.blockMigrated'.  The result is the simplified lambda paired
-- with the statements produced alongside it.
--
-- NOTE(review): the exact hoisting policy is decided by
-- 'Engine.blockMigrated', which is defined elsewhere - consult the
-- simplification engine before relying on what is (not) hoisted.
simplifyLambda ::
  Engine.SimplifiableRep rep =>
  Lambda (Wise rep) ->
  Engine.SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambda = Engine.blockMigrated . Engine.simplifyLambda

-- | Build a symbol table describing the names bound by a 'SegSpace'.
--
-- The table starts out as a scope containing only the flat index,
-- registered as an 'IndexName' of type 'Int64'.  Each
-- @(gtid, dim)@ pair of the space is then inserted, in order, as an
-- 'Int64' loop variable with bound @dim@ via 'ST.insertLoopVar'.
segSpaceSymbolTable :: ASTRep rep => SegSpace -> ST.SymbolTable rep
segSpaceSymbolTable (SegSpace flat gtids_and_dims) =
  foldl' f (ST.fromScope $ M.singleton flat $ IndexName Int64) gtids_and_dims
  where
    -- Register one thread index together with its dimension bound.
    f vtable (gtid, dim) = ST.insertLoopVar gtid Int64 dim vtable

-- | Simplify a single 'SegBinOp'.
--
-- The operator lambda is simplified with 'ST.simplifyMemory' switched
-- on in the local symbol table; the shape and the neutral elements are
-- simplified with 'Engine.simplify'.  The commutativity annotation is
-- passed through unchanged.  Returns the rebuilt operator together
-- with the statements that 'simplifyLambda' returned alongside the
-- lambda.
simplifySegBinOp ::
  Engine.SimplifiableRep rep =>
  SegBinOp (Wise rep) ->
  Engine.SimpleM rep (SegBinOp (Wise rep), Stms (Wise rep))
simplifySegBinOp (SegBinOp comm lam nes shape) = do
  (lam', hoisted) <-
    Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True}) $
      simplifyLambda lam
  shape' <- Engine.simplify shape
  nes' <- mapM Engine.simplify nes
  pure (SegBinOp comm lam' nes' shape', hoisted)

-- | Simplify the given 'SegOp'.
--
-- For every constructor the level, space and result types are
-- simplified together with 'Engine.simplify', and the kernel body is
-- simplified with 'simplifyKernelBody'.  For 'SegRed', 'SegScan' and
-- 'SegHist' the operators are additionally simplified with the names
-- bound by the 'SegSpace' added to the local symbol table.
--
-- The second component of the result collects the statements returned
-- by the operator simplification (if any) followed by those returned
-- by the body simplification.
simplifySegOp ::
  ( Engine.SimplifiableRep rep,
    BodyDec rep ~ (),
    Engine.Simplifiable lvl
  ) =>
  SegOp lvl (Wise rep) ->
  Engine.SimpleM rep (SegOp lvl (Wise rep), Stms (Wise rep))
simplifySegOp (SegMap lvl space ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  -- Note: the body is simplified against the original @space@, while
  -- the rebuilt op carries the simplified @space'@.
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure
    ( SegMap lvl' space' ts' kbody',
      body_hoisted
    )
simplifySegOp (SegRed lvl space reds ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  -- The reduction operators are simplified with the names bound by the
  -- space added to the symbol table.
  (reds', reds_hoisted) <-
    Engine.localVtable (<> scope_vtable) $
      unzip <$> mapM simplifySegBinOp reds
  (kbody', body_hoisted) <- simplifyKernelBody space kbody

  pure
    ( SegRed lvl' space' reds' ts' kbody',
      mconcat reds_hoisted <> body_hoisted
    )
  where
    scope = scopeOfSegSpace space
    scope_vtable = ST.fromScope scope
simplifySegOp (SegScan lvl space scans ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  -- Same treatment as for the operators of a 'SegRed'.
  (scans', scans_hoisted) <-
    Engine.localVtable (<> scope_vtable) $
      unzip <$> mapM simplifySegBinOp scans
  (kbody', body_hoisted) <- simplifyKernelBody space kbody

  pure
    ( SegScan lvl' space' scans' ts' kbody',
      mconcat scans_hoisted <> body_hoisted
    )
  where
    scope = scopeOfSegSpace space
    scope_vtable = ST.fromScope scope
simplifySegOp (SegHist lvl space ops ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)

  (ops', ops_hoisted) <- fmap unzip $
    forM ops $
      \(HistOp w rf arrs nes dims lam) -> do
        w' <- Engine.simplify w
        rf' <- Engine.simplify rf
        arrs' <- Engine.simplify arrs
        nes' <- Engine.simplify nes
        dims' <- Engine.simplify dims
        -- The histogram operator sees the names bound by the space and
        -- is simplified with 'ST.simplifyMemory' switched on.
        (lam', op_hoisted) <-
          Engine.localVtable (<> scope_vtable) $
            Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True}) $
              simplifyLambda lam
        pure
          ( HistOp w' rf' arrs' nes' dims' lam',
            op_hoisted
          )

  (kbody', body_hoisted) <- simplifyKernelBody space kbody

  pure
    ( SegHist lvl' space' ops' ts' kbody',
      mconcat ops_hoisted <> body_hoisted
    )
  where
    scope = scopeOfSegSpace space
    scope_vtable = ST.fromScope scope

-- | Does this rep contain 'SegOp's in its t'Op's?  A rep must be an
-- instance of this class for the simplification rules to work.
class HasSegOp rep where
  -- | The level type used as the first parameter of the 'SegOp's of
  -- this rep.
  type SegOpLevel rep

  -- | Try to view an operation of the rep as a 'SegOp'.  The rules in
  -- this module skip any operation for which this is not a 'Just'.
  asSegOp :: Op rep -> Maybe (SegOp (SegOpLevel rep) rep)

  -- | Turn a 'SegOp' into an operation of the rep.
  segOp :: SegOp (SegOpLevel rep) rep -> Op rep

-- | Simplification rules for simplifying 'SegOp's.
--
-- The rule book contains exactly one top-down rule
-- ('segOpRuleTopDown') and one bottom-up rule ('segOpRuleBottomUp'),
-- both registered as rules on operations via 'RuleOp'.
segOpRules ::
  (HasSegOp rep, BuilderOps rep, Buildable rep, Aliased rep) =>
  RuleBook rep
segOpRules =
  ruleBook [RuleOp segOpRuleTopDown] [RuleOp segOpRuleBottomUp]

-- | Top-down rule on operations: if the operation can be viewed as a
-- 'SegOp' (via 'asSegOp'), delegate to 'topDownSegOp'; otherwise the
-- rule is skipped.
segOpRuleTopDown ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  TopDownRuleOp rep
segOpRuleTopDown vtable pat dec op
  | Just op' <- asSegOp op =
      topDownSegOp vtable pat dec op'
  | otherwise =
      Skip

-- | Bottom-up rule on operations: if the operation can be viewed as a
-- 'SegOp' (via 'asSegOp'), delegate to 'bottomUpSegOp'; otherwise the
-- rule is skipped.
segOpRuleBottomUp ::
  (HasSegOp rep, BuilderOps rep, Aliased rep) =>
  BottomUpRuleOp rep
segOpRuleBottomUp vtable pat dec op
  | Just op' <- asSegOp op =
      bottomUpSegOp vtable pat dec op'
  | otherwise =
      Skip

topDownSegOp ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  ST.SymbolTable rep ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
-- If a SegOp produces something invariant to the SegOp, turn it
-- into a replicate.
topDownSegOp :: SymbolTable rep
-> Pat (LetDec rep)
-> StmAux (ExpDec rep)
-> SegOp (SegOpLevel rep) rep
-> Rule rep
topDownSegOp SymbolTable rep
vtable (Pat [PatElem (LetDec rep)]
kpes) StmAux (ExpDec rep)
dec (SegMap SegOpLevel rep
lvl SegSpace
space [Type]
ts (KernelBody BodyDec rep
_ Stms rep
kstms [KernelResult]
kres)) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
  ([Type]
ts', [PatElem (LetDec rep)]
kpes', [KernelResult]
kres') <-
    [(Type, PatElem (LetDec rep), KernelResult)]
-> ([Type], [PatElem (LetDec rep)], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Type, PatElem (LetDec rep), KernelResult)]
 -> ([Type], [PatElem (LetDec rep)], [KernelResult]))
-> RuleM rep [(Type, PatElem (LetDec rep), KernelResult)]
-> RuleM rep ([Type], [PatElem (LetDec rep)], [KernelResult])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Type, PatElem (LetDec rep), KernelResult) -> RuleM rep Bool)
-> [(Type, PatElem (LetDec rep), KernelResult)]
-> RuleM rep [(Type, PatElem (LetDec rep), KernelResult)]
forall (m :: * -> *) a.
Applicative m =>
(a -> m Bool) -> [a] -> m [a]
filterM (Type, PatElem (LetDec rep), KernelResult) -> RuleM rep Bool
checkForInvarianceResult ([Type]
-> [PatElem (LetDec rep)]
-> [KernelResult]
-> [(Type, PatElem (LetDec rep), KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
ts [PatElem (LetDec rep)]
kpes [KernelResult]
kres)

  -- Check if we did anything at all.
  Bool -> RuleM rep () -> RuleM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([KernelResult]
kres [KernelResult] -> [KernelResult] -> Bool
forall a. Eq a => a -> a -> Bool
== [KernelResult]
kres') RuleM rep ()
forall rep a. RuleM rep a
cannotSimplify

  KernelBody rep
kbody <- Stms (Rep (RuleM rep))
-> [KernelResult] -> RuleM rep (KernelBody (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
Stms (Rep m) -> [KernelResult] -> m (KernelBody (Rep m))
mkKernelBodyM Stms rep
Stms (Rep (RuleM rep))
kstms [KernelResult]
kres'
  Stm (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (RuleM rep)) -> RuleM rep ())
-> Stm (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
    Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
kpes') StmAux (ExpDec rep)
dec (Exp rep -> Stm rep) -> Exp rep -> Stm rep
forall a b. (a -> b) -> a -> b
$
      Op rep -> Exp rep
forall rep. Op rep -> Exp rep
Op (Op rep -> Exp rep) -> Op rep -> Exp rep
forall a b. (a -> b) -> a -> b
$
        SegOp (SegOpLevel rep) rep -> Op rep
forall rep. HasSegOp rep => SegOp (SegOpLevel rep) rep -> Op rep
segOp (SegOp (SegOpLevel rep) rep -> Op rep)
-> SegOp (SegOpLevel rep) rep -> Op rep
forall a b. (a -> b) -> a -> b
$
          SegOpLevel rep
-> SegSpace
-> [Type]
-> KernelBody rep
-> SegOp (SegOpLevel rep) rep
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap SegOpLevel rep
lvl SegSpace
space [Type]
ts' KernelBody rep
kbody
  where
    isInvariant :: SubExp -> Bool
isInvariant Constant {} = Bool
True
    isInvariant (Var VName
v) = Maybe (Entry rep) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Entry rep) -> Bool) -> Maybe (Entry rep) -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> SymbolTable rep -> Maybe (Entry rep)
forall rep. VName -> SymbolTable rep -> Maybe (Entry rep)
ST.lookup VName
v SymbolTable rep
vtable

    checkForInvarianceResult :: (Type, PatElem (LetDec rep), KernelResult) -> RuleM rep Bool
checkForInvarianceResult (Type
_, PatElem (LetDec rep)
pe, Returns ResultManifest
rm Certs
cs SubExp
se)
      | Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty,
        ResultManifest
rm ResultManifest -> ResultManifest -> Bool
forall a. Eq a => a -> a -> Bool
== ResultManifest
ResultMaySimplify,
        SubExp -> Bool
isInvariant SubExp
se = do
          [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
            BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$
              Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space) SubExp
se
          Bool -> RuleM rep Bool
forall (f :: * -> *) a. Applicative f => a -> f a
pure Bool
False
    checkForInvarianceResult (Type, PatElem (LetDec rep), KernelResult)
_ =
      Bool -> RuleM rep Bool
forall (f :: * -> *) a. Applicative f => a -> f a
pure Bool
True

-- If a SegRed contains two reduction operations that have the same
-- vector shape, merge them together.  This saves on communication
-- overhead, but can in principle lead to more local memory usage.
topDownSegOp SymbolTable rep
_ (Pat [PatElem (LetDec rep)]
pes) StmAux (ExpDec rep)
_ (SegRed SegOpLevel rep
lvl SegSpace
space [SegBinOp rep]
ops [Type]
ts KernelBody rep
kbody)
  | [SegBinOp rep] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegBinOp rep]
ops Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
    [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings <-
      ((SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
 -> (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
 -> Bool)
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
forall a. (a -> a -> Bool) -> [a] -> [[a]]
groupBy (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
-> (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
-> Bool
forall rep b rep b. (SegBinOp rep, b) -> (SegBinOp rep, b) -> Bool
sameShape ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
 -> [[(SegBinOp rep,
       [(PatElem (LetDec rep), Type, KernelResult)])]])
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
forall a b. (a -> b) -> a -> b
$
        [SegBinOp rep]
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp rep]
ops ([[(PatElem (LetDec rep), Type, KernelResult)]]
 -> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])])
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
forall a b. (a -> b) -> a -> b
$
          [Int]
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp rep -> Int) -> [SegBinOp rep] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp rep -> [SubExp]) -> SegBinOp rep -> Int
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral) [SegBinOp rep]
ops) ([(PatElem (LetDec rep), Type, KernelResult)]
 -> [[(PatElem (LetDec rep), Type, KernelResult)]])
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> [[(PatElem (LetDec rep), Type, KernelResult)]]
forall a b. (a -> b) -> a -> b
$
            [PatElem (LetDec rep)]
-> [Type]
-> [KernelResult]
-> [(PatElem (LetDec rep), Type, KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (LetDec rep)]
red_pes [Type]
red_ts [KernelResult]
red_res,
    ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
 -> Bool)
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
-> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any ((Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1) (Int -> Bool)
-> ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
    -> Int)
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Bool
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length) [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
      let ([SegBinOp rep]
ops', [[(PatElem (LetDec rep), Type, KernelResult)]]
aux) = [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> ([SegBinOp rep], [[(PatElem (LetDec rep), Type, KernelResult)]])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
 -> ([SegBinOp rep],
     [[(PatElem (LetDec rep), Type, KernelResult)]]))
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> ([SegBinOp rep], [[(PatElem (LetDec rep), Type, KernelResult)]])
forall a b. (a -> b) -> a -> b
$ ([(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
 -> Maybe
      (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)]))
-> [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
-> [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe [(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]
-> Maybe
     (SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])
forall rep a.
Buildable rep =>
[(SegBinOp rep, [a])] -> Maybe (SegBinOp rep, [a])
combineOps [[(SegBinOp rep, [(PatElem (LetDec rep), Type, KernelResult)])]]
op_groupings
          ([PatElem (LetDec rep)]
red_pes', [Type]
red_ts', [KernelResult]
red_res') = [(PatElem (LetDec rep), Type, KernelResult)]
-> ([PatElem (LetDec rep)], [Type], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(PatElem (LetDec rep), Type, KernelResult)]
 -> ([PatElem (LetDec rep)], [Type], [KernelResult]))
-> [(PatElem (LetDec rep), Type, KernelResult)]
-> ([PatElem (LetDec rep)], [Type], [KernelResult])
forall a b. (a -> b) -> a -> b
$ [[(PatElem (LetDec rep), Type, KernelResult)]]
-> [(PatElem (LetDec rep), Type, KernelResult)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[(PatElem (LetDec rep), Type, KernelResult)]]
aux
          pes' :: [PatElem (LetDec rep)]
pes' = [PatElem (LetDec rep)]
red_pes' [PatElem (LetDec rep)]
-> [PatElem (LetDec rep)] -> [PatElem (LetDec rep)]
forall a. [a] -> [a] -> [a]
++ [PatElem (LetDec rep)]
map_pes
          ts' :: [Type]
ts' = [Type]
red_ts' [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
map_ts
          kbody' :: KernelBody rep
kbody' = KernelBody rep
kbody {kernelBodyResult :: [KernelResult]
kernelBodyResult = [KernelResult]
red_res' [KernelResult] -> [KernelResult] -> [KernelResult]
forall a. [a] -> [a] -> [a]
++ [KernelResult]
map_res}
      Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
pes') (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Op rep -> Exp rep
forall rep. Op rep -> Exp rep
Op (Op rep -> Exp rep) -> Op rep -> Exp rep
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel rep) rep -> Op rep
forall rep. HasSegOp rep => SegOp (SegOpLevel rep) rep -> Op rep
segOp (SegOp (SegOpLevel rep) rep -> Op rep)
-> SegOp (SegOpLevel rep) rep -> Op rep
forall a b. (a -> b) -> a -> b
$ SegOpLevel rep
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp (SegOpLevel rep) rep
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [Type]
-> KernelBody rep
-> SegOp lvl rep
SegRed SegOpLevel rep
lvl SegSpace
space [SegBinOp rep]
ops' [Type]
ts' KernelBody rep
kbody'
  where
    ([PatElem (LetDec rep)]
red_pes, [PatElem (LetDec rep)]
map_pes) = Int
-> [PatElem (LetDec rep)]
-> ([PatElem (LetDec rep)], [PatElem (LetDec rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp rep] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) [PatElem (LetDec rep)]
pes
    ([Type]
red_ts, [Type]
map_ts) = Int -> [Type] -> ([Type], [Type])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp rep] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) [Type]
ts
    ([KernelResult]
red_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp rep] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp rep]
ops) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody rep -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody rep
kbody

    sameShape :: (SegBinOp rep, b) -> (SegBinOp rep, b) -> Bool
sameShape (SegBinOp rep
op1, b
_) (SegBinOp rep
op2, b
_) = SegBinOp rep -> Shape
forall rep. SegBinOp rep -> Shape
segBinOpShape SegBinOp rep
op1 Shape -> Shape -> Bool
forall a. Eq a => a -> a -> Bool
== SegBinOp rep -> Shape
forall rep. SegBinOp rep -> Shape
segBinOpShape SegBinOp rep
op2

    combineOps :: [(SegBinOp rep, [a])] -> Maybe (SegBinOp rep, [a])
combineOps [] = Maybe (SegBinOp rep, [a])
forall a. Maybe a
Nothing
    combineOps ((SegBinOp rep, [a])
x : [(SegBinOp rep, [a])]
xs) = (SegBinOp rep, [a]) -> Maybe (SegBinOp rep, [a])
forall a. a -> Maybe a
Just ((SegBinOp rep, [a]) -> Maybe (SegBinOp rep, [a]))
-> (SegBinOp rep, [a]) -> Maybe (SegBinOp rep, [a])
forall a b. (a -> b) -> a -> b
$ ((SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a]))
-> (SegBinOp rep, [a])
-> [(SegBinOp rep, [a])]
-> (SegBinOp rep, [a])
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' (SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a])
forall rep a.
Buildable rep =>
(SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a])
combine (SegBinOp rep, [a])
x [(SegBinOp rep, [a])]
xs

    combine :: (SegBinOp rep, [a]) -> (SegBinOp rep, [a]) -> (SegBinOp rep, [a])
combine (SegBinOp rep
op1, [a]
op1_aux) (SegBinOp rep
op2, [a]
op2_aux) =
      let lam1 :: Lambda rep
lam1 = SegBinOp rep -> Lambda rep
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp rep
op1
          lam2 :: Lambda rep
lam2 = SegBinOp rep -> Lambda rep
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp rep
op2
          ([Param (LParamInfo rep)]
op1_xparams, [Param (LParamInfo rep)]
op1_yparams) =
            Int
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op1)) ([Param (LParamInfo rep)]
 -> ([Param (LParamInfo rep)], [Param (LParamInfo rep)]))
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [Param (LParamInfo rep)]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam1
          ([Param (LParamInfo rep)]
op2_xparams, [Param (LParamInfo rep)]
op2_yparams) =
            Int
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op2)) ([Param (LParamInfo rep)]
 -> ([Param (LParamInfo rep)], [Param (LParamInfo rep)]))
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [Param (LParamInfo rep)]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam2
          lam :: Lambda rep
lam =
            Lambda :: forall rep. [LParam rep] -> Body rep -> [Type] -> Lambda rep
Lambda
              { lambdaParams :: [Param (LParamInfo rep)]
lambdaParams =
                  [Param (LParamInfo rep)]
op1_xparams
                    [Param (LParamInfo rep)]
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
op2_xparams
                    [Param (LParamInfo rep)]
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
op1_yparams
                    [Param (LParamInfo rep)]
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
op2_yparams,
                lambdaReturnType :: [Type]
lambdaReturnType = Lambda rep -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
lam1 [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ Lambda rep -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
lam2,
                lambdaBody :: Body rep
lambdaBody =
                  Stms rep -> Result -> Body rep
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (Body rep -> Stms rep
forall rep. Body rep -> Stms rep
bodyStms (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam1) Stms rep -> Stms rep -> Stms rep
forall a. Semigroup a => a -> a -> a
<> Body rep -> Stms rep
forall rep. Body rep -> Stms rep
bodyStms (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam2)) (Result -> Body rep) -> Result -> Body rep
forall a b. (a -> b) -> a -> b
$
                    Body rep -> Result
forall rep. Body rep -> Result
bodyResult (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam1) Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> Body rep -> Result
forall rep. Body rep -> Result
bodyResult (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam2)
              }
       in ( SegBinOp :: forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp
              { segBinOpComm :: Commutativity
segBinOpComm = SegBinOp rep -> Commutativity
forall rep. SegBinOp rep -> Commutativity
segBinOpComm SegBinOp rep
op1 Commutativity -> Commutativity -> Commutativity
forall a. Semigroup a => a -> a -> a
<> SegBinOp rep -> Commutativity
forall rep. SegBinOp rep -> Commutativity
segBinOpComm SegBinOp rep
op2,
                segBinOpLambda :: Lambda rep
segBinOpLambda = Lambda rep
lam,
                segBinOpNeutral :: [SubExp]
segBinOpNeutral = SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op1 [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ SegBinOp rep -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp rep
op2,
                segBinOpShape :: Shape
segBinOpShape = SegBinOp rep -> Shape
forall rep. SegBinOp rep -> Shape
segBinOpShape SegBinOp rep
op1 -- Same as shape of op2 due to the grouping.
              },
            [a]
op1_aux [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
op2_aux
          )
topDownSegOp SymbolTable rep
_ Pat (LetDec rep)
_ StmAux (ExpDec rep)
_ SegOp (SegOpLevel rep) rep
_ = Rule rep
forall rep. Rule rep
Skip

-- A convenient way of operating on the type and body of a SegOp,
-- without worrying about exactly what kind it is.
--
-- The components of the returned tuple are, in order:
--
--   * the result types stored in the SegOp;
--
--   * its kernel body;
--
--   * a result count: 0 for 'SegMap', 'segBinOpResults' of the
--     operators for 'SegScan' and 'SegRed', and the total number of
--     'histDest' arrays for 'SegHist';
--
--   * a function that rebuilds a SegOp of the same constructor, with
--     the same level, space and operators, from new result types and
--     a new kernel body.
segOpGuts ::
  SegOp (SegOpLevel rep) rep ->
  ( [Type],
    KernelBody rep,
    Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
  )
segOpGuts (SegMap lvl space kts body) =
  (kts, body, 0, SegMap lvl space)
segOpGuts (SegScan lvl space ops kts body) =
  (kts, body, segBinOpResults ops, SegScan lvl space ops)
segOpGuts (SegRed lvl space ops kts body) =
  (kts, body, segBinOpResults ops, SegRed lvl space ops)
segOpGuts (SegHist lvl space ops kts body) =
  (kts, body, sum $ map (length . histDest) ops, SegHist lvl space ops)

-- Some SegOp results can be moved outside the SegOp, which can
-- simplify further analysis.
--
-- The rule walks the statements of the kernel body.  A statement is
-- hoisted when it is an index into an array at exactly the index
-- variables of the SegOp's space (optionally followed by further
-- slicing), and the value it binds is returned directly as a result
-- of the SegOp.  The corresponding pattern element is then bound
-- outside the SegOp to a slice of the same array, and the result is
-- dropped from the SegOp.  If nothing could be hoisted, the rule
-- reports that it cannot simplify.
bottomUpSegOp ::
  (Aliased rep, HasSegOp rep, BuilderOps rep) =>
  (ST.SymbolTable rep, UT.UsageTable) ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
bottomUpSegOp (vtable, used) (Pat kpes) dec segop = Simplify $ do
  -- Iterate through the bindings.  For each, we check whether it is
  -- in kres and can be moved outside.  If so, we remove it from kres
  -- and kpes and make it a binding outside.  We have to be careful
  -- not to remove anything that is passed on to a scan/map/histogram
  -- operation.  Fortunately, these are always first in the result
  -- list.
  (kpes', kts', kres', kstms') <-
    localScope (scopeOfSegSpace space) $
      foldM distribute (kpes, kts, kres, mempty) kstms

  -- An unchanged pattern means no result was hoisted.
  when (kpes' == kpes) cannotSimplify

  kbody' <-
    localScope (scopeOfSegSpace space) $
      mkKernelBodyM kstms' kres'

  addStm $ Let (Pat kpes') dec $ Op $ segOp $ mk_segop kts' kbody'
  where
    (kts, kbody@(KernelBody _ kstms kres), num_nonmap_results, mk_segop) =
      segOpGuts segop
    -- Every name occurring free in the original kernel statements.
    free_in_kstms = foldMap freeIn kstms
    consumed_in_segop = consumedInKernelBody kbody
    space = segSpace segop

    -- Recognise a statement of the form @arr[gtids..., rest...]@, where
    -- the leading indices are exactly the index variables of the
    -- space, fixed in order.  Additionally, every name free in @arr@
    -- and in the remaining slice must be present in the symbol table
    -- (presumably meaning it is available outside the SegOp).  On
    -- success, yields the remaining slice and the array.
    sliceWithGtidsFixed stm
      | Let _ _ (BasicOp (Index arr slice)) <- stm,
        space_slice <- map (DimFix . Var . fst) $ unSegSpace space,
        space_slice `isPrefixOf` unSlice slice,
        remaining_slice <- Slice $ drop (length space_slice) (unSlice slice),
        all (isJust . flip ST.lookup vtable) $
          namesToList $
            freeIn arr <> freeIn remaining_slice =
          Just (remaining_slice, arr)
      | otherwise =
          Nothing

    -- One step of the fold over the kernel statements.  The
    -- accumulator holds the pattern elements, types and results still
    -- kept in the SegOp, plus the statements kept in its body.
    distribute (kpes', kts', kres', kstms') stm
      | Let (Pat [pe]) _ _ <- stm,
        Just (Slice remaining_slice, arr) <- sliceWithGtidsFixed stm,
        Just (kpe, kpes'', kts'', kres'') <- isResult kpes' kts' kres' pe = do
          let -- One slice per dimension of the space, with offset 0,
              -- the dimension size as extent, and stride 1.
              outer_slice =
                map
                  ( \d ->
                      DimSlice
                        (constant (0 :: Int64))
                        d
                        (constant (1 :: Int64))
                  )
                  $ segSpaceDims space
              -- Bind the given pattern element, outside the SegOp, to
              -- the corresponding slice of the array.
              index kpe' =
                letBindNames [patElemName kpe'] . BasicOp . Index arr $
                  Slice $
                    outer_slice <> remaining_slice
          -- If the result is consumed later, or the array is consumed
          -- inside the SegOp, go via an explicit copy instead of
          -- binding the result directly to the slice.
          if (patElemName kpe `UT.isConsumed` used)
            || (arr `nameIn` consumed_in_segop)
            then do
              precopy <- newVName $ baseString (patElemName kpe) <> "_precopy"
              index kpe {patElemName = precopy}
              letBindNames [patElemName kpe] $ BasicOp $ Copy precopy
            else index kpe
          pure
            ( kpes'',
              kts'',
              kres'',
              -- Keep the statement in the body only if its name occurs
              -- free in the kernel statements.
              if patElemName pe `nameIn` free_in_kstms
                then kstms' <> oneStm stm
                else kstms'
            )
    distribute (kpes', kts', kres', kstms') stm =
      pure (kpes', kts', kres', kstms' <> oneStm stm)

    -- Determine whether the variable bound by @pe@ is returned by
    -- exactly one of the remaining kernel results, and whether that
    -- result lies at or beyond position @num_nonmap_results@ in the
    -- original pattern.  On success, yields the matching pattern
    -- element together with the remaining pattern elements, types and
    -- results, with the match removed.
    isResult kpes' kts' kres' pe =
      case partition matches $ zip3 kpes' kts' kres' of
        ([(kpe, _, _)], kpes_and_kres)
          | Just i <- elemIndex kpe kpes,
            i >= num_nonmap_results,
            (kpes'', kts'', kres'') <- unzip3 kpes_and_kres ->
              Just (kpe, kpes'', kts'', kres'')
        _ -> Nothing
      where
        matches (_, _, Returns _ _ (Var v)) = v == patElemName pe
        matches _ = False

--- Memory

-- Adjust a list of return descriptions to match the results of a
-- kernel body.  The kernel results and the given returns are paired
-- up positionally; for a 'WriteReturns' result the return is replaced
-- by the 'varReturns' of the array being written, and every other
-- return is passed through unchanged.
kernelBodyReturns ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  KernelBody somerep ->
  [ExpReturns] ->
  m [ExpReturns]
kernelBodyReturns = zipWithM correct . kernelBodyResult
  where
    correct (WriteReturns _ _ arr _) _ = varReturns arr
    correct _ ret = pure ret

-- | Like 'segOpType', but for memory representations.
--
-- For 'SegMap', 'SegRed' and 'SegScan' the result is obtained by
-- taking 'opType' of the whole operation, converting it with
-- 'extReturns', and then adjusting it with 'kernelBodyReturns' applied
-- to the operation's kernel body.  For 'SegHist' the result is built
-- only from the destination arrays ('histDest') of the histogram
-- operations.
segOpReturns ::
  (Mem rep inner, Monad m, HasScope rep m) =>
  SegOp lvl somerep ->
  m [ExpReturns]
-- NOTE(review): the repeated signature on the next line, and the type
-- lines that follow each identifier below, look like tool-emitted
-- inferred-type annotations rather than hand-written code -- confirm
-- before editing this block by hand.
segOpReturns :: SegOp lvl somerep -> m [ExpReturns]
-- SegMap: only the kernel body is used from the pattern; the type
-- comes from 'opType' on the whole operation @k@.
segOpReturns k :: SegOp lvl somerep
k@(SegMap lvl
_ SegSpace
_ [Type]
_ KernelBody somerep
kbody) =
  KernelBody somerep -> [ExpReturns] -> m [ExpReturns]
forall rep inner (m :: * -> *) somerep.
(Mem rep inner, HasScope rep m, Monad m) =>
KernelBody somerep -> [ExpReturns] -> m [ExpReturns]
kernelBodyReturns KernelBody somerep
kbody ([ExpReturns] -> m [ExpReturns])
-> ([ExtType] -> [ExpReturns]) -> [ExtType] -> m [ExpReturns]
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [ExtType] -> [ExpReturns]
extReturns ([ExtType] -> m [ExpReturns]) -> m [ExtType] -> m [ExpReturns]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOp lvl somerep -> m [ExtType]
forall op t (m :: * -> *).
(TypedOp op, HasScope t m) =>
op -> m [ExtType]
opType SegOp lvl somerep
k
-- SegRed: same pipeline as the SegMap clause; the reduction operators
-- are not inspected here.
segOpReturns k :: SegOp lvl somerep
k@(SegRed lvl
_ SegSpace
_ [SegBinOp somerep]
_ [Type]
_ KernelBody somerep
kbody) =
  KernelBody somerep -> [ExpReturns] -> m [ExpReturns]
forall rep inner (m :: * -> *) somerep.
(Mem rep inner, HasScope rep m, Monad m) =>
KernelBody somerep -> [ExpReturns] -> m [ExpReturns]
kernelBodyReturns KernelBody somerep
kbody ([ExpReturns] -> m [ExpReturns])
-> ([ExtType] -> [ExpReturns]) -> [ExtType] -> m [ExpReturns]
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [ExtType] -> [ExpReturns]
extReturns ([ExtType] -> m [ExpReturns]) -> m [ExtType] -> m [ExpReturns]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOp lvl somerep -> m [ExtType]
forall op t (m :: * -> *).
(TypedOp op, HasScope t m) =>
op -> m [ExtType]
opType SegOp lvl somerep
k
-- SegScan: same pipeline as the SegMap clause; the scan operators are
-- not inspected here.
segOpReturns k :: SegOp lvl somerep
k@(SegScan lvl
_ SegSpace
_ [SegBinOp somerep]
_ [Type]
_ KernelBody somerep
kbody) =
  KernelBody somerep -> [ExpReturns] -> m [ExpReturns]
forall rep inner (m :: * -> *) somerep.
(Mem rep inner, HasScope rep m, Monad m) =>
KernelBody somerep -> [ExpReturns] -> m [ExpReturns]
kernelBodyReturns KernelBody somerep
kbody ([ExpReturns] -> m [ExpReturns])
-> ([ExtType] -> [ExpReturns]) -> [ExtType] -> m [ExpReturns]
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [ExtType] -> [ExpReturns]
extReturns ([ExtType] -> m [ExpReturns]) -> m [ExtType] -> m [ExpReturns]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOp lvl somerep -> m [ExtType]
forall op t (m :: * -> *).
(TypedOp op, HasScope t m) =>
op -> m [ExtType]
opType SegOp lvl somerep
k
-- SegHist: the kernel body and 'opType' are not consulted.  For each
-- histogram operation, 'varReturns' is applied to every name in its
-- 'histDest', and the per-operation lists are concatenated in
-- operation order.
segOpReturns (SegHist lvl
_ SegSpace
_ [HistOp somerep]
ops [Type]
_ KernelBody somerep
_) =
  [[ExpReturns]] -> [ExpReturns]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[ExpReturns]] -> [ExpReturns])
-> m [[ExpReturns]] -> m [ExpReturns]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp somerep -> m [ExpReturns])
-> [HistOp somerep] -> m [[ExpReturns]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((VName -> m ExpReturns) -> [VName] -> m [ExpReturns]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> m ExpReturns
forall rep (m :: * -> *) inner.
(HasScope rep m, Monad m, Mem rep inner) =>
VName -> m ExpReturns
varReturns ([VName] -> m [ExpReturns])
-> (HistOp somerep -> [VName]) -> HistOp somerep -> m [ExpReturns]
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. HistOp somerep -> [VName]
forall rep. HistOp rep -> [VName]
histDest) [HistOp somerep]
ops